From ca5e5d388eeecab14b89f3717b8811a4f1e39a41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mehmet=20Efe=20Ak=C3=A7a?= <13402668+mmtftr@users.noreply.github.com> Date: Wed, 18 Dec 2024 21:45:37 +0300 Subject: [PATCH] fix: check if model checkpoint exists before loading --- bin/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bin/main.py b/bin/main.py index e3d0941..3206a09 100644 --- a/bin/main.py +++ b/bin/main.py @@ -516,12 +516,13 @@ def train_validate_test(gpu, args): # Log the number of parameters by layer count_parameters_by_layer(model.module) - # Load the model weights if --load-model argument is provided (using the DATA_PATH directory as the root) + # Load the model weights if the model_file exists (using the DATA_PATH directory as the root) # TODO: Process model loading in the get_setup function - if args.model_file: + checkpoint_path = os.path.join(config["DATA_PATH"], args.model_file), + if args.model_file and os.path.exists(checkpoint_path): load_model( trainer=Trainer, - checkpoint_path=os.path.join(config["DATA_PATH"], args.model_file), + checkpoint_path=checkpoint_path, rank=rank, from_checkpoint=args.from_checkpoint, )