diff --git a/official_python_implementation/src/utils.py b/official_python_implementation/src/utils.py index 6a2817b..110a46b 100644 --- a/official_python_implementation/src/utils.py +++ b/official_python_implementation/src/utils.py @@ -45,7 +45,7 @@ def get_model_and_optimizer(opt): main_model_params = [ p for p in model.parameters() - if all(p is not x for x in model.classification_loss.parameters()) + if all(p is not x for x in model.linear_classifier.parameters()) ] optimizer = torch.optim.SGD( [ @@ -56,7 +56,7 @@ def get_model_and_optimizer(opt): "momentum": opt.training.momentum, }, { - "params": model.classification_loss.parameters(), + "params": model.linear_classifier.parameters(), "lr": opt.training.downstream_learning_rate, "weight_decay": opt.training.downstream_weight_decay, "momentum": opt.training.momentum,