diff --git a/bin/main.py b/bin/main.py index e3d0941..3206a09 100644 --- a/bin/main.py +++ b/bin/main.py @@ -516,12 +516,13 @@ def train_validate_test(gpu, args): # Log the number of parameters by layer count_parameters_by_layer(model.module) - # Load the model weights if --load-model argument is provided (using the DATA_PATH directory as the root) + # Load the model weights if the model_file exists (using the DATA_PATH directory as the root) # TODO: Process model loading in the get_setup function - if args.model_file: + checkpoint_path = os.path.join(config["DATA_PATH"], args.model_file), + if args.model_file and os.path.exists(checkpoint_path): load_model( trainer=Trainer, - checkpoint_path=os.path.join(config["DATA_PATH"], args.model_file), + checkpoint_path=checkpoint_path, rank=rank, from_checkpoint=args.from_checkpoint, )