[BUG] load_model_trainer_states_from_checkpoint() function is broken #806

@Legoclones

Bug description

The transformers4rec.torch.trainer.load_model_trainer_states_from_checkpoint() function can only load t4rec_model_class.pkl; all the other loads are broken. The utils.serialization.load() function is called four times in the function, but only one of the calls succeeds; the others throw errors. I believe this is because load() expects an _io.BufferedReader object (the result of open()) rather than a path.
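To illustrate the class of bug (a minimal sketch, not the actual transformers4rec internals): a `load()` helper built on `pickle.load()` needs a file-like object with a `.read()` method, so handing it a path string fails immediately.

```python
import pickle
import tempfile

def load(f):
    # Sketch of a serialization helper that, like utils.serialization.load(),
    # expects an already opened file object rather than a path.
    return pickle.load(f)

with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as tmp:
    pickle.dump({"x": 1}, tmp)
    path = tmp.name

# Passing the path string fails: pickle.load() has no file to read from.
try:
    load(path)
except TypeError as e:
    print("path fails:", e)

# Passing an open file object works.
with open(path, "rb") as f:
    print("file object works:", load(f))
```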

Correctly implemented:

```python
model = load(
    open(os.path.join(checkpoint_path, "t4rec_model_class.pkl"), "rb")
)
```

Incorrectly implemented:

```python
load(os.path.join(checkpoint_path, "pytorch_model.bin"), torch_load=True)

checkpoint_rng_state = load(rng_file, torch_load=True)

load(os.path.join(checkpoint_path, "scaler.pt"), torch_load=True)
```
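A possible fix (a sketch under the assumption that `load()` keeps its current signature, which the source does not confirm) is to wrap each path in `open(..., "rb")` so every call matches the one that already works. The names `checkpoint_path`, `rng_file`, and `torch_load` mirror the snippets above; `load_checkpoint_states` is a hypothetical helper for illustration only.

```python
import os

def load_checkpoint_states(checkpoint_path, load, rng_file):
    # Hypothetical sketch: pass file objects, not paths, to load(),
    # matching the working t4rec_model_class.pkl call.
    with open(os.path.join(checkpoint_path, "pytorch_model.bin"), "rb") as f:
        model_state = load(f, torch_load=True)
    with open(rng_file, "rb") as f:
        checkpoint_rng_state = load(f, torch_load=True)
    with open(os.path.join(checkpoint_path, "scaler.pt"), "rb") as f:
        scaler_state = load(f, torch_load=True)
    return model_state, checkpoint_rng_state, scaler_state
```

(Note that `torch.load()` itself accepts either a path or a file-like object, so the file-object form stays compatible with the `torch_load=True` branch.)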

Steps/Code to reproduce bug

Run the following code with a Model checkpoint at /checkpoint:

```python
from transformers4rec import torch as tr

trainer = tr.trainer.Trainer(model=tr.model.base.Model(), args=None)
trainer.load_model_trainer_states_from_checkpoint('/checkpoint')
```

Expected behavior

The function call should load all the saved trainer states without raising an error.

Environment details

  • Transformers4Rec version: 23.12.0
  • Platform: Linux
  • Python version: 3.10.12
  • Huggingface Transformers version: 4.27.1
  • PyTorch version (GPU?): 2.1.2 (no)
