From 43576b0a9f47d3d40a1004c1a0a887dba66e7746 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Thu, 12 Jun 2025 08:15:00 +0200 Subject: [PATCH] fix issue #125 --- clip_benchmark/metrics/linear_probe.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/clip_benchmark/metrics/linear_probe.py b/clip_benchmark/metrics/linear_probe.py index ead75f0..bfc25c4 100644 --- a/clip_benchmark/metrics/linear_probe.py +++ b/clip_benchmark/metrics/linear_probe.py @@ -277,7 +277,13 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor peak_idx = find_peak(wd_list, [left, peak_idx, right], feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed) step_span //= 2 best_wd = wd_list[peak_idx] - train_loader = feature_train_val_loader + if fewshot_k < 0: + # if we are doing full training, we use the full training set (train+val) + train_loader = feature_train_val_loader + else: + # if we are doing few-shot learning, we use the few-shot training set only + # as adding the validation set will train on more data than intended + train_loader = feature_train_loader else: best_wd = 0 train_loader = feature_train_loader