diff --git a/clip_benchmark/metrics/image_caption_selection.py b/clip_benchmark/metrics/image_caption_selection.py
index 85b2eb3..c84c467 100644
--- a/clip_benchmark/metrics/image_caption_selection.py
+++ b/clip_benchmark/metrics/image_caption_selection.py
@@ -34,7 +34,6 @@ def evaluate(model, dataloader, tokenizer, device, amp=True):
 
     dict of accuracy metrics
     """
-    autocast = torch.cuda.amp.autocast if amp else suppress
     image_score = []
     text_score = []
     score = []
@@ -52,7 +51,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True):
         # tokenize all texts in the batch
         batch_texts_tok_ = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(device)
         # compute the embedding of images and texts
-        with torch.no_grad(), autocast():
+        with torch.no_grad(), torch.autocast(torch.device(device).type, enabled=amp):
             batch_images_emb = F.normalize(model.encode_image(batch_images_), dim=-1).view(B, nim, -1)
             batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok_), dim=-1).view(B, nt, -1)
             gt = torch.arange(min(nim, nt)).to(device)
diff --git a/clip_benchmark/metrics/linear_probe.py b/clip_benchmark/metrics/linear_probe.py
index ead75f0..6ac7128 100644
--- a/clip_benchmark/metrics/linear_probe.py
+++ b/clip_benchmark/metrics/linear_probe.py
@@ -56,7 +56,7 @@ def __getitem__(self, i):
         return self.features[i], self.targets[i]
 
 
-def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autocast, device, seed):
+def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, amp, device, seed):
     torch.manual_seed(seed)
     model = torch.nn.Linear(input_shape, output_shape)
     devices = [x for x in range(torch.cuda.device_count())]
@@ -81,7 +81,7 @@ def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autoc
             scheduler(step)
             optimizer.zero_grad()
 
-            with autocast():
+            with torch.autocast(torch.device(device).type, enabled=amp):
                 pred = model(x)
                 loss = criterion(pred, y)
 
@@ -107,14 +107,14 @@ def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autoc
     return model
 
 
-def infer(model, dataloader, autocast, device):
+def infer(model, dataloader, amp, device):
     true, pred = [], []
     with torch.no_grad():
         for x, y in tqdm(dataloader):
             x = x.to(device)
             y = y.to(device)
 
-            with autocast():
+            with torch.autocast(torch.device(device).type, enabled=amp):
                 logits = model(x)
 
                 pred.append(logits.cpu())
@@ -125,12 +125,12 @@ def infer(model, dataloader, autocast, device):
     return logits, target
 
 
-def find_peak(wd_list, idxs, train_loader, val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed):
+def find_peak(wd_list, idxs, train_loader, val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed):
     best_wd_idx, max_acc = 0, 0
     for idx in idxs:
         weight_decay = wd_list[idx]
-        model = train(train_loader, input_shape, output_shape, weight_decay, lr, epochs, autocast, device, seed)
-        logits, target = infer(model, val_loader, autocast, device)
+        model = train(train_loader, input_shape, output_shape, weight_decay, lr, epochs, amp, device, seed)
+        logits, target = infer(model, val_loader, amp, device)
         acc1, = accuracy(logits.float(), target.float(), topk=(1,))
         if verbose:
             print(f"Valid accuracy with weight_decay {weight_decay}: {acc1}")
@@ -150,7 +150,6 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
         os.mkdir(feature_dir)
 
     featurizer = Featurizer(model, normalize).cuda()
-    autocast = torch.cuda.amp.autocast if amp else suppress
     if not os.path.exists(os.path.join(feature_dir, 'targets_train.pt')):
         # now we have to cache the features
         devices = [x for x in range(torch.cuda.device_count())]
@@ -168,7 +167,7 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
             for images, target in tqdm(loader):
                 images = images.to(device)
 
-                with autocast():
+                with torch.autocast(torch.device(device).type, enabled=amp):
                     feature = featurizer(images)
 
                 features.append(feature.cpu())
@@ -270,11 +269,11 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
         wd_list = np.logspace(-6, 2, num=97).tolist()
         wd_list_init = np.logspace(-6, 2, num=7).tolist()
         wd_init_idx = [i for i, val in enumerate(wd_list) if val in wd_list_init]
-        peak_idx = find_peak(wd_list, wd_init_idx, feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed)
+        peak_idx = find_peak(wd_list, wd_init_idx, feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed)
         step_span = 8
         while step_span > 0:
             left, right = max(peak_idx - step_span, 0), min(peak_idx + step_span, len(wd_list)-1)
-            peak_idx = find_peak(wd_list, [left, peak_idx, right], feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed)
+            peak_idx = find_peak(wd_list, [left, peak_idx, right], feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed)
             step_span //= 2
         best_wd = wd_list[peak_idx]
         train_loader = feature_train_val_loader
@@ -282,8 +281,8 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
         best_wd = 0
         train_loader = feature_train_loader
 
-    final_model = train(train_loader, input_shape, output_shape, best_wd, lr, epochs, autocast, device, seed)
-    logits, target = infer(final_model, feature_test_loader, autocast, device)
+    final_model = train(train_loader, input_shape, output_shape, best_wd, lr, epochs, amp, device, seed)
+    logits, target = infer(final_model, feature_test_loader, amp, device)
     pred = logits.argmax(axis=1)
 
     # measure accuracy
diff --git a/clip_benchmark/metrics/zeroshot_classification.py b/clip_benchmark/metrics/zeroshot_classification.py
index e3962aa..9c70e1e 100644
--- a/clip_benchmark/metrics/zeroshot_classification.py
+++ b/clip_benchmark/metrics/zeroshot_classification.py
@@ -12,6 +12,7 @@
 from sklearn.metrics import classification_report, balanced_accuracy_score
 
 
+
 def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=True):
     """
     This function returns zero-shot vectors for each class in order
@@ -36,8 +37,7 @@ def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=Tr
 
         torch.Tensor of shape (N,C) where N is the number of templates, and C is the number of classes.
     """
-    autocast = torch.cuda.amp.autocast if amp else suppress
-    with torch.no_grad(), autocast():
+    with torch.no_grad(), torch.autocast(torch.device(device).type, enabled=amp):
         zeroshot_weights = []
         for classname in tqdm(classnames):
             if type(templates) == dict:
@@ -100,7 +100,6 @@ def run_classification(model, classifier, dataloader, device, amp=True):
     - pred (N, C) are the logits
     - true (N,) are the actual classes
     """
-    autocast = torch.cuda.amp.autocast if amp else suppress
     pred = []
     true = []
     nb = 0
@@ -109,7 +108,7 @@ def run_classification(model, classifier, dataloader, device, amp=True):
         images = images.to(device)
         target = target.to(device)
 
-        with autocast():
+        with torch.autocast(torch.device(device).type, enabled=amp):
             # predict
             image_features = model.encode_image(images)
             image_features = F.normalize(image_features, dim=-1)
diff --git a/clip_benchmark/metrics/zeroshot_retrieval.py b/clip_benchmark/metrics/zeroshot_retrieval.py
index 3f1426a..1b5354d 100644
--- a/clip_benchmark/metrics/zeroshot_retrieval.py
+++ b/clip_benchmark/metrics/zeroshot_retrieval.py
@@ -39,8 +39,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
     batch_texts_emb_list = []
     # for each text, we collect the corresponding image index, as each image can have multiple corresponding texts
     texts_image_index = []
-    dataloader = dataloader_with_indices(dataloader)
-    autocast = torch.cuda.amp.autocast if amp else suppress
+    dataloader = dataloader_with_indices(dataloader)
     for batch_images, batch_texts, inds in tqdm(dataloader):
         batch_images = batch_images.to(device)
         # tokenize all texts in the batch
@@ -49,7 +48,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
         batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]
 
         # compute the embedding of images and texts
-        with torch.no_grad(), autocast():
+        with torch.no_grad(), torch.autocast(torch.device(device).type, enabled=amp):
             batch_images_emb = F.normalize(model.encode_image(batch_images), dim=-1)
             batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok), dim=-1)