diff --git a/rdt/train.py b/rdt/train.py index 1356bc5..f884bcc 100644 --- a/rdt/train.py +++ b/rdt/train.py @@ -448,7 +448,7 @@ def save_model_hook(models, weights, output_dir): accelerator.print(f"Resuming from checkpoint {path}") try: accelerator.load_state(os.path.join(args.output_dir, path)) # load_module_strict=False - except: + except Exception: # load deepspeed's state_dict logger.info("Resuming training state failed. Attempting to only load from model checkpoint.") checkpoint = torch.load(os.path.join(args.output_dir, path, "pytorch_model", "mp_rank_00_model_states.pt"))