diff --git a/clip_benchmark/metrics/linear_probe.py b/clip_benchmark/metrics/linear_probe.py
index ead75f0..bfc25c4 100644
--- a/clip_benchmark/metrics/linear_probe.py
+++ b/clip_benchmark/metrics/linear_probe.py
@@ -277,7 +277,13 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
             peak_idx = find_peak(wd_list, [left, peak_idx, right], feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed)
             step_span //= 2
         best_wd = wd_list[peak_idx]
-        train_loader = feature_train_val_loader
+        if fewshot_k < 0:
+            # if we are doing full training, we use the full training set (train+val)
+            train_loader = feature_train_val_loader
+        else:
+            # if we are doing few-shot learning, we use the few-shot training set only
+            # as adding the validation set will train on more data than intended
+            train_loader = feature_train_loader
     else:
         best_wd = 0
         train_loader = feature_train_loader