Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions clip_benchmark/metrics/image_caption_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def evaluate(model, dataloader, tokenizer, device, amp=True):

dict of accuracy metrics
"""
autocast = torch.cuda.amp.autocast if amp else suppress
image_score = []
text_score = []
score = []
Expand All @@ -52,7 +51,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True):
# tokenize all texts in the batch
batch_texts_tok_ = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(device)
# compute the embedding of images and texts
with torch.no_grad(), autocast():
with torch.no_grad(), torch.autocast(device, enabled=amp):
batch_images_emb = F.normalize(model.encode_image(batch_images_), dim=-1).view(B, nim, -1)
batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok_), dim=-1).view(B, nt, -1)
gt = torch.arange(min(nim, nt)).to(device)
Expand Down
25 changes: 12 additions & 13 deletions clip_benchmark/metrics/linear_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __getitem__(self, i):
return self.features[i], self.targets[i]


def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autocast, device, seed):
def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, amp, device, seed):
torch.manual_seed(seed)
model = torch.nn.Linear(input_shape, output_shape)
devices = [x for x in range(torch.cuda.device_count())]
Expand All @@ -81,7 +81,7 @@ def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autoc
scheduler(step)

optimizer.zero_grad()
with autocast():
with torch.autocast(device, enabled=amp):
pred = model(x)
loss = criterion(pred, y)

Expand All @@ -107,14 +107,14 @@ def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autoc
return model


def infer(model, dataloader, autocast, device):
def infer(model, dataloader, amp, device):
true, pred = [], []
with torch.no_grad():
for x, y in tqdm(dataloader):
x = x.to(device)
y = y.to(device)

with autocast():
with torch.autocast(device, enabled=amp):
logits = model(x)

pred.append(logits.cpu())
Expand All @@ -125,12 +125,12 @@ def infer(model, dataloader, autocast, device):
return logits, target


def find_peak(wd_list, idxs, train_loader, val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed):
def find_peak(wd_list, idxs, train_loader, val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed):
best_wd_idx, max_acc = 0, 0
for idx in idxs:
weight_decay = wd_list[idx]
model = train(train_loader, input_shape, output_shape, weight_decay, lr, epochs, autocast, device, seed)
logits, target = infer(model, val_loader, autocast, device)
model = train(train_loader, input_shape, output_shape, weight_decay, lr, epochs, amp, device, seed)
logits, target = infer(model, val_loader, amp, device)
acc1, = accuracy(logits.float(), target.float(), topk=(1,))
if verbose:
print(f"Valid accuracy with weight_decay {weight_decay}: {acc1}")
Expand All @@ -150,7 +150,6 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
os.mkdir(feature_dir)

featurizer = Featurizer(model, normalize).cuda()
autocast = torch.cuda.amp.autocast if amp else suppress
if not os.path.exists(os.path.join(feature_dir, 'targets_train.pt')):
# now we have to cache the features
devices = [x for x in range(torch.cuda.device_count())]
Expand All @@ -168,7 +167,7 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
for images, target in tqdm(loader):
images = images.to(device)

with autocast():
with torch.autocast(device, enabled=amp):
feature = featurizer(images)

features.append(feature.cpu())
Expand Down Expand Up @@ -270,20 +269,20 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
wd_list = np.logspace(-6, 2, num=97).tolist()
wd_list_init = np.logspace(-6, 2, num=7).tolist()
wd_init_idx = [i for i, val in enumerate(wd_list) if val in wd_list_init]
peak_idx = find_peak(wd_list, wd_init_idx, feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed)
peak_idx = find_peak(wd_list, wd_init_idx, feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed)
step_span = 8
while step_span > 0:
left, right = max(peak_idx - step_span, 0), min(peak_idx + step_span, len(wd_list)-1)
peak_idx = find_peak(wd_list, [left, peak_idx, right], feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, autocast, device, verbose, seed)
peak_idx = find_peak(wd_list, [left, peak_idx, right], feature_train_loader, feature_val_loader, input_shape, output_shape, lr, epochs, amp, device, verbose, seed)
step_span //= 2
best_wd = wd_list[peak_idx]
train_loader = feature_train_val_loader
else:
best_wd = 0
train_loader = feature_train_loader

final_model = train(train_loader, input_shape, output_shape, best_wd, lr, epochs, autocast, device, seed)
logits, target = infer(final_model, feature_test_loader, autocast, device)
final_model = train(train_loader, input_shape, output_shape, best_wd, lr, epochs, amp, device, seed)
logits, target = infer(final_model, feature_test_loader, amp, device)
pred = logits.argmax(axis=1)

# measure accuracy
Expand Down
7 changes: 3 additions & 4 deletions clip_benchmark/metrics/zeroshot_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sklearn.metrics import classification_report, balanced_accuracy_score



def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=True):
"""
This function returns zero-shot vectors for each class in order
Expand All @@ -36,8 +37,7 @@ def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=Tr
torch.Tensor of shape (N,C) where N is the number
of templates, and C is the number of classes.
"""
autocast = torch.cuda.amp.autocast if amp else suppress
with torch.no_grad(), autocast():
with torch.no_grad(), torch.autocast(device, enabled=amp):
zeroshot_weights = []
for classname in tqdm(classnames):
if type(templates) == dict:
Expand Down Expand Up @@ -100,7 +100,6 @@ def run_classification(model, classifier, dataloader, device, amp=True):
- pred (N, C) are the logits
- true (N,) are the actual classes
"""
autocast = torch.cuda.amp.autocast if amp else suppress
pred = []
true = []
nb = 0
Expand All @@ -109,7 +108,7 @@ def run_classification(model, classifier, dataloader, device, amp=True):
images = images.to(device)
target = target.to(device)

with autocast():
with torch.autocast(device, enabled=amp):
# predict
image_features = model.encode_image(images)
image_features = F.normalize(image_features, dim=-1)
Expand Down
5 changes: 2 additions & 3 deletions clip_benchmark/metrics/zeroshot_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
batch_texts_emb_list = []
# for each text, we collect the corresponding image index, as each image can have multiple corresponding texts
texts_image_index = []
dataloader = dataloader_with_indices(dataloader)
autocast = torch.cuda.amp.autocast if amp else suppress
dataloader = dataloader_with_indices(dataloader)
for batch_images, batch_texts, inds in tqdm(dataloader):
batch_images = batch_images.to(device)
# tokenize all texts in the batch
Expand All @@ -49,7 +48,7 @@ def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5])
batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]

# compute the embedding of images and texts
with torch.no_grad(), autocast():
with torch.no_grad(), torch.autocast(device, enabled=amp):
batch_images_emb = F.normalize(model.encode_image(batch_images), dim=-1)
batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok), dim=-1)

Expand Down