diff --git a/deepethogram/losses.py b/deepethogram/losses.py index 7adb16c..7345d14 100644 --- a/deepethogram/losses.py +++ b/deepethogram/losses.py @@ -191,7 +191,7 @@ def get_regularization_loss(cfg: DictConfig, model): pretrained_dir = cfg.project.pretrained_path assert os.path.isdir(pretrained_dir) weights = projects.get_weights_from_model_path(pretrained_dir) - pretrained_file = weights[cfg.run.model][cfg[cfg.run.model].arch] + pretrained_file = weights[cfg.run.model].get(cfg[cfg.run.model].arch, []) if len(pretrained_file) == 0: log.warning('No pretrained file found. Regularization: L2. alpha={}'.format(