Bug description
The transformers4rec.torch.trainer.load_model_trainer_states_from_checkpoint() function can only load t4rec_model_class.pkl; every other load in it is broken. The function calls utils.serialization.load() four times, but only one of those calls succeeds and the others throw errors. I believe this is because load() expects an _io.BufferedReader object (the result of open()) rather than a path.
Correctly implemented:
Transformers4Rec/transformers4rec/torch/trainer.py
Lines 744 to 746 in 5d59d14
    model = load(
        open(os.path.join(checkpoint_path, "t4rec_model_class.pkl"), "rb")
    )
Incorrectly implemented:
    load(os.path.join(checkpoint_path, "pytorch_model.bin"), torch_load=True)
    checkpoint_rng_state = load(rng_file, torch_load=True)
    load(os.path.join(checkpoint_path, "scaler.pt"), torch_load=True)
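To illustrate the suspected cause, here is a minimal, self-contained sketch. It does not call Transformers4Rec. `load_from_file_object` is a hypothetical stand-in for a loader that only accepts an open binary file, and `load_from_path` is a hypothetical wrapper showing the kind of change that might fix the three failing calls, assuming the diagnosis above is right.

```python
import os
import pickle
import tempfile


def load_from_file_object(f):
    """Hypothetical stand-in for a loader that expects an open binary file."""
    return pickle.load(f)


def load_from_path(path):
    """Hypothetical fix: open the path in binary mode, then pass the file object."""
    with open(path, "rb") as f:
        return load_from_file_object(f)


with tempfile.TemporaryDirectory() as checkpoint_path:
    # Write a small placeholder file; the name mirrors one from the report.
    target = os.path.join(checkpoint_path, "scaler.pt")
    with open(target, "wb") as f:
        pickle.dump({"scale": 2.0}, f)

    # Passing the path string directly fails, as with the broken calls.
    try:
        load_from_file_object(target)
        path_error = None
    except TypeError as exc:
        path_error = exc

    # Opening the file first succeeds, as with the working call.
    state = load_from_path(target)
```

If the real load() behaves like this stand-in, wrapping each of the three paths in open(..., "rb"), as the working call already does, should be enough.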
Steps/Code to reproduce bug
Run the following code with a Model checkpoint at /checkpoint:
    from transformers4rec import torch as tr
    trainer = tr.trainer.Trainer(model=tr.model.base.Model(), args=None)
    trainer.load_model_trainer_states_from_checkpoint('/checkpoint')
Expected behavior
The function call should complete without raising an error.
Environment details
- Transformers4Rec version: 23.12.0
- Platform: Linux
- Python version: 3.10.12
- Huggingface Transformers version: 4.27.1
- PyTorch version (GPU?): 2.1.2 (no)