diff --git a/.gitignore b/.gitignore index a9a721b..4bc9eeb 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,10 @@ *.ipynb_checkpoints *.json *.pth.tar + + +.idea +data +pretrained_models +*.tar +*.ipynb \ No newline at end of file diff --git a/README.md b/README.md index cb3837f..cedb9d8 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,12 @@ conda activate teran export PYTHONPATH=. ``` +2.1 Setup minimal python environment for CUDA 10.1 using conda: +``` +conda env create --file environment_min.yml +conda activate teran +export PYTHONPATH=. +``` ## Get the data 1. Download and extract the data folder, containing annotations, the splits by Karpathy et al. and ROUGEL - SPICE precomputed relevances for both COCO and Flickr30K datasets: diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..705e854 --- /dev/null +++ b/__init__.py @@ -0,0 +1 @@ +from .data import d \ No newline at end of file diff --git a/configs/teran_coco_MrSw_IR.yaml b/configs/teran_coco_MrSw_IR.yaml new file mode 100644 index 0000000..0c79a1f --- /dev/null +++ b/configs/teran_coco_MrSw_IR.yaml @@ -0,0 +1,63 @@ +dataset: + name: 'coco' + images-path: 'data/coco/images' # not needed if using pre-extracted bottom-up features + data: 'data' + restval: True + pre-extracted-features: False + +image-retrieval: + dataset: 'coco' # for now only coco support + split: 'test' # we can remove this in later versions + num_imgs: 5000 + batch_size: 100 # 100 takes ~10s; 1000 takes ~14s to encode the data (compute the TE outputs) + pre_extracted_img_features_root: 'data/coco/features_36' + create_query_batch: False + alignment_mode: 'MrSw' + use_precomputed_img_embeddings: False + pre_computed_img_embeddings_root: 'data/coco/pre_computed_embeddings' + +text-model: + name: 'bert' + pretrain: 'bert-base-uncased' + word-dim: 768 + extraction-hidden-layer: 6 + fine-tune: True + pre-extracted: False + layers: 0 + dropout: 0.1 + +image-model: + name: 'bottomup' + pre-extracted-features-root: 'data/coco/features_36' + transformer-layers: 4 + dropout: 0.1 + pos-encoding: 'concat-and-process' + crop-size: 224 # not used + fine-tune: False + feat-dim: 2048 + norm: True + +model: + name: 'teran' + embed-size: 1024 + text-aggregation: 'first' + image-aggregation: 'first' + layers: 2 + exclude-stopwords: False + shared-transformer: False + dropout: 0.1 + +training: + lr: 0.00001 # 0.000006 + grad-clip: 2.0 + max-violation: True + loss-type: 'alignment' + alignment-mode: 'MrSw' + measure: 'dot' + margin: 0.2 + bs: 40 + scheduler: 'steplr' + gamma: 0.1 + step-size: 20 + warmup: null + warmup-period: 1000 diff --git a/configs/teran_coco_MrSw_IR_PreComp.yaml b/configs/teran_coco_MrSw_IR_PreComp.yaml new file mode 100644 index 0000000..7be7a6e --- /dev/null +++ b/configs/teran_coco_MrSw_IR_PreComp.yaml @@ -0,0 +1,63 @@ +dataset: + name: 'coco' + images-path: 'data/coco/images' # not needed if using pre-extracted bottom-up features + data: 'data' + restval: True + pre-extracted-features: False + +image-retrieval: + dataset: 'coco' # for now only coco support + split: 'test' # we can remove this in later versions + num_imgs: 5000 + batch_size: 100 # 100 takes ~10s; 1000 takes ~14s to encode the data (compute the TE outputs) + pre_extracted_img_features_root: 'data/coco/features_36' + create_query_batch: False + alignment_mode: 'MrSw' + use_precomputed_img_embeddings: True + pre_computed_img_embeddings_root: 'data/coco/pre_computed_embeddings' + +text-model: + name: 'bert' + pretrain: 'bert-base-uncased' + word-dim: 768 + extraction-hidden-layer: 6 + fine-tune: True + pre-extracted: False + layers: 0 + dropout: 0.1 + +image-model: + name: 'bottomup' + pre-extracted-features-root: 'data/coco/features_36' + transformer-layers: 4 + dropout: 0.1 + pos-encoding: 'concat-and-process' + crop-size: 224 # not used + fine-tune: False + feat-dim: 2048 + norm: True + +model: + name: 'teran' + embed-size: 1024 + text-aggregation: 'first' + image-aggregation: 'first' + layers: 2 + exclude-stopwords: False + shared-transformer: False + dropout: 0.1 + +training: + lr: 0.00001 # 0.000006 + grad-clip: 2.0 + max-violation: True + loss-type: 'alignment' + alignment-mode: 'MrSw' + measure: 'dot' + margin: 0.2 + bs: 40 + scheduler: 'steplr' + gamma: 0.1 + step-size: 20 + warmup: null + warmup-period: 1000 diff --git a/configs/teran_inf_coco_MrSw.yaml b/configs/teran_inf_coco_MrSw.yaml new file mode 100644 index 0000000..d3458aa --- /dev/null +++ b/configs/teran_inf_coco_MrSw.yaml @@ -0,0 +1,59 @@ +dataset: + name: 'coco' + images-path: 'data/coco/images' # not needed if using pre-extracted bottom-up features + data: 'data' + restval: True + pre-extracted-features: False + +text-model: + name: 'bert' + pretrain: 'bert-base-uncased' + word-dim: 768 + extraction-hidden-layer: 6 + fine-tune: True + pre-extracted: False + layers: 0 + dropout: 0.1 + +#text-model: +# name: 'gru' +# word-dim: 300 +# fine-tune: True +# pre-extracted: False +# layers: 1 + +image-model: + name: 'bottomup' + pre-extracted-features-root: 'data/coco/features_36' + transformer-layers: 4 + dropout: 0.1 + pos-encoding: 'concat-and-process' + crop-size: 224 # not used + fine-tune: False + feat-dim: 2048 + norm: True + +model: + name: 'teran' + embed-size: 1024 + text-aggregation: 'first' + image-aggregation: 'first' + layers: 2 + exclude-stopwords: False + shared-transformer: False + dropout: 0.1 + +training: + lr: 0.00001 # 0.000006 + grad-clip: 2.0 + max-violation: True + loss-type: 'alignment' + alignment-mode: 'MrSw' + measure: 'dot' + margin: 0.2 + bs: 40 + scheduler: 'steplr' + gamma: 0.1 + step-size: 20 + warmup: null + warmup-period: 1000 diff --git a/data.py b/data.py index 9b988fd..15f83b0 100644 --- a/data.py +++ b/data.py @@ -1,39 +1,42 @@ +import json as jsonmod +import os +import pickle +import time +from collections import OrderedDict +from multiprocessing import Pool + +import numpy as np import torch import torch.utils.data as data import torchvision.transforms as transforms -import os -import nltk +import tqdm from PIL import Image from pycocotools.coco import COCO -import numpy as np -import json as jsonmod -from collections.abc import Sequence -import shelve from transformers import BertTokenizer -import pickle -import tqdm from features import HuggingFaceTransformerExtractor def get_paths(config): + # noinspection PyIncorrectDocstring + # noinspection PyUnresolvedReferences """ - Returns paths to images and annotations for the given datasets. For MSCOCO - indices are also returned to control the data split being used. - The indices are extracted from the Karpathy et al. splits using this - snippet: - - >>> import json - >>> dataset=json.load(open('dataset_coco.json','r')) - >>> A=[] - >>> for i in range(len(D['images'])): - ... if D['images'][i]['split'] == 'val': - ... A+=D['images'][i]['sentids'][:5] - ... - - :param name: Dataset names - :param use_restval: If True, the the `restval` data is included in train. - """ + Returns paths to images and annotations for the given datasets. For MSCOCO + indices are also returned to control the data split being used. + The indices are extracted from the Karpathy et al. splits using this + snippet: + + >>> import json + >>> dataset=json.load(open('dataset_coco.json','r')) + >>> A=[] + >>> for i in range(len(D['images'])): + ... if D['images'][i]['split'] == 'val': + ... A+=D['images'][i]['sentids'][:5] + ... + + :param name: Dataset names + :param use_restval: If True, the the `restval` data is included in train. + """ name = config['dataset']['name'] annotations_path = os.path.join(config['dataset']['data'], name, 'annotations') use_restval = config['dataset']['restval'] @@ -64,7 +67,8 @@ def get_paths(config): ids['test'] = np.load(os.path.join(annotations_path, 'coco_test_ids.npy')) ids['trainrestval'] = ( ids['train'], - np.load(os.path.join(annotations_path, 'coco_restval_ids.npy'))) + np.load(os.path.join(annotations_path, 'coco_restval_ids.npy')) + ) if use_restval: roots['train'] = roots['trainrestval'] ids['train'] = ids['trainrestval'] @@ -82,33 +86,33 @@ def get_paths(config): class CocoDataset(data.Dataset): """COCO Custom Dataset compatible with torch.utils.data.DataLoader.""" - def __init__(self, root, json, transform=None, ids=None, get_images=True): + def __init__(self, imgs_root, captions_json, transform=None, coco_annotation_ids=None, get_images=True): """ Args: - root: image directory. - json: coco annotation file path. + imgs_root: image directory. + captions_json: coco annotation file path. transform: transformer for image. """ - self.root = root + self.root = imgs_root self.get_images = get_images # when using `restval`, two json files are needed - if isinstance(json, tuple): - self.coco = (COCO(json[0]), COCO(json[1])) + if isinstance(captions_json, tuple): + self.coco = (COCO(captions_json[0]), COCO(captions_json[1])) else: - self.coco = (COCO(json),) - self.root = (root,) + self.coco = (COCO(captions_json),) + self.root = (imgs_root,) # if ids provided by get_paths, use split-specific ids - if ids is None: - self.ids = list(self.coco.anns.keys()) + if coco_annotation_ids is None: + self.annotation_ids = list(self.coco[0].anns.keys()) else: - self.ids = ids + self.annotation_ids = coco_annotation_ids # if `restval` data is to be used, record the break point for ids - if isinstance(self.ids, tuple): - self.bp = len(self.ids[0]) - self.ids = list(self.ids[0]) + list(self.ids[1]) + if isinstance(self.annotation_ids, tuple): + self.bp = len(self.annotation_ids[0]) + self.annotation_ids = list(self.annotation_ids[0]) + list(self.annotation_ids[1]) else: - self.bp = len(self.ids) + self.bp = len(self.annotation_ids) self.transform = transform def __getitem__(self, index): @@ -123,17 +127,17 @@ def __getitem__(self, index): return image, target, index, img_id def get_raw_item(self, index, load_image=True): - if index < self.bp: + if index < self.bp: # bp -> breakpoint to stop after N samples coco = self.coco[0] root = self.root[0] else: coco = self.coco[1] root = self.root[1] - ann_id = self.ids[index] + ann_id = self.annotation_ids[index] caption = coco.anns[ann_id]['caption'] img_id = coco.anns[ann_id]['image_id'] - img = coco.imgs[img_id] - img_size = np.array([img['width'], img['height']]) + img_metadata = coco.imgs[img_id] + img_size = np.array([img_metadata['width'], img_metadata['height']]) if load_image: path = coco.loadImgs(img_id)[0]['file_name'] image = Image.open(os.path.join(root, path)).convert('RGB') @@ -143,18 +147,151 @@ def get_raw_item(self, index, load_image=True): return root, caption, img_id, None, None, img_size def __len__(self): - return len(self.ids) + return len(self.annotation_ids) + + +class CocoImageRetrievalDatasetBase: + def __init__(self, captions_json, coco_annotation_ids, num_imgs): + self.num_imgs = num_imgs + + self.coco = COCO(captions_json) + self.anno_ids = coco_annotation_ids + + def get_image_metadata(self, idx): + next_img_idx = idx * 5 # in the coco dataset there are 5 captions for every image + ann_id = self.anno_ids[next_img_idx] + coco_img_id = self.coco.anns[ann_id]['image_id'] + img_metadata = self.coco.imgs[coco_img_id] + return coco_img_id, img_metadata + + +# This has to be outside any class so that it can be pickled for multiproc +def load_img_emb(args): + # just return the query and the img embedding + idx, file_name = args + npz = np.load(file_name) + img_emd = npz.get('img_emb') + return idx, img_emd + + +class PreComputedCocoImageEmbeddingsDataset(CocoImageRetrievalDatasetBase): + """ + Custom COCO Dataset that uses pre-computed image embedding + """ + + def __init__(self, captions_json, coco_annotation_ids, num_imgs, config, num_workers=32): + CocoImageRetrievalDatasetBase.__init__(self, captions_json, coco_annotation_ids, num_imgs) + + pre_computed_img_embeddings_root = config['image-retrieval']['pre_computed_img_embeddings_root'] + self.pre_computed_img_embeddings_root = pre_computed_img_embeddings_root + self.num_workers = num_workers + + self.img_embs = self.__load_img_embs() + + def __load_img_embs(self): + start = time.time() + print('Parallel loading of pre-computed image embeddings started...') + file_names = list(map(lambda m: os.path.join(self.pre_computed_img_embeddings_root, m[1]['file_name'] + '.npz'), + [self.get_image_metadata(i) for i in range(self.num_imgs)])) + # parallel loading of all image embeddings + with Pool(self.num_workers) as pool: + res = pool.map(load_img_emb, enumerate(file_names)) + pool.join() + res = OrderedDict(res) + print(f'Time elapsed to load pre-computed image embeddings: {time.time() - start} seconds') + return res + + def __len__(self): + return self.num_imgs + + +class QueryEncoder: + def __init__(self, config, model): + self.vocab_type = str(config['text-model']['name']).lower() + if self.vocab_type == 'bert': + self.tokenizer = BertTokenizer.from_pretrained(config['text-model']['pretrain']) + elif self.vocab_type != 'bert': + raise ValueError("Currently only BERT Tokenizer is supported!") + + self.model = model + + def _get_query_pseudo_batch(self, query: str): + # tokenize and encode the query + query_token_ids = torch.LongTensor(self.tokenizer.encode(query)) + # create a pseudo batch suitable for TERAN + query_token_pseudo_batch = query_token_ids.unsqueeze(dim=0) + query_lengths = [len(query_token_ids)] + return query_token_pseudo_batch, query_lengths + + def compute_query_embedding(self, query): + # compute the query embedding + with torch.no_grad(): + start_query_batch = time.time() + query_token_pseudo_batch, query_lengths = self._get_query_pseudo_batch(query) + print(f'Time to get query pseudo batch: {time.time() - start_query_batch}') + + start_query_enc = time.time() + query_emb_aggr, query_emb, _ = self.model.forward_txt(query_token_pseudo_batch, query_lengths) + print(f'Time to compute query embedding: {time.time() - start_query_enc}') + + # store results as np arrays for further processing or persisting + query_feat_dim = query_emb.size(2) + query_embs = torch.zeros((1, query_lengths[0], query_feat_dim), requires_grad=False) + query_embs[0, :, :] = query_emb.cpu().permute(1, 0, 2) + + return query_embs, query_lengths + + +class PreComputedCocoFeaturesDataset(CocoImageRetrievalDatasetBase, data.Dataset): + """ + Custom COCO Dataset that uses only the images together with a user query. + Compatible with torch.utils.data.DataLoader. + """ + + def __init__(self, imgs_root, img_features_path, captions_json, coco_annotation_ids, query, num_imgs): + CocoImageRetrievalDatasetBase.__init__(self, captions_json, coco_annotation_ids, num_imgs) + + self.feats_data_path = os.path.join(img_features_path, 'bu_att') + self.box_data_path = os.path.join(img_features_path, 'bu_box') + self.imgs_root = imgs_root + self.query = query + + def __getitem__(self, idx): + """ + This function returns a tuple that is further passed to collate_fn + """ + img_id, img_metadata = self.get_image_metadata(idx) + img_size = np.array([img_metadata['width'], img_metadata['height']]) + + img_feat_path = os.path.join(self.feats_data_path, '{}.npz'.format(img_id)) + img_box_path = os.path.join(self.box_data_path, '{}.npy'.format(img_id)) + + img_feat = np.load(img_feat_path)['feat'] + img_feat_box = np.load(img_box_path) + + # normalize box + img_feat_box = img_feat_box / np.tile(img_size, 2) + + img_feat = torch.Tensor(img_feat) + img_feat_box = torch.Tensor(img_feat_box) + + # we always return the query here since we want to compute the similarity of each image with the query + # this output is the input of the CollateFn + return img_feat, img_feat_box, img_id, self.query, idx + + def __len__(self): + return self.num_imgs class BottomUpFeaturesDataset: - def __init__(self, root, json, features_path, split, ids=None, **kwargs): + def __init__(self, imgs_root, captions_json, features_path, split, ids=None, **kwargs): # which dataset? - r = root[0] if type(root) == tuple else root + r = imgs_root[0] if type(imgs_root) == tuple else imgs_root r = r.lower() if 'coco' in r: - self.underlying_dataset = CocoDataset(root, json, ids=ids) + self.underlying_dataset = CocoDataset(imgs_root, captions_json, coco_annotation_ids=ids) elif 'f30k' in r or 'flickr30k' in r: - self.underlying_dataset = FlickrDataset(root, json, split) + self.underlying_dataset = FlickrDataset(imgs_root, captions_json, split) # data_path = config['image-model']['data-path'] self.feats_data_path = os.path.join(features_path, 'bu_att') @@ -191,7 +328,7 @@ def __getitem__(self, index): else: target = caption # image = (img_feat, img_boxes) - return img_feat, img_boxes, target, index, img_id + return img_feat, img_boxes, target, index, img_id # target is the actual caption sentence def __len__(self): return len(self.underlying_dataset) @@ -256,12 +393,119 @@ def get_raw_item(self, index, load_image=True): else: return root, caption, img_id, None, None, img_size - - def __len__(self): return len(self.ids) +class InferenceCollate(object): + def __new__(cls, *args, **kwargs): + # we only need to compute this once so it gets stored in a static class variable + cls.query_token_ids = None + cls.query_length = None + cls.img_feat_length = None + cls.img_feat_dim = None + cls.bboxes_length = None + cls.bboxes_dim = None + + return super(InferenceCollate, cls).__new__(cls) + + def __init__(self, config, pre_compute_img_embs): + self.create_query_batch = bool(config['image-retrieval']['create_query_batch']) + self.pre_compute_img_embs = pre_compute_img_embs + self.vocab_type = str(config['text-model']['name']).lower() + if self.vocab_type == 'bert' and not pre_compute_img_embs: + self.tokenizer = BertTokenizer.from_pretrained(config['text-model']['pretrain']) + elif self.vocab_type != 'bert': + raise ValueError("Currently only BERT Tokenizer is supported!") + + @classmethod + def set_query_token_ids(cls, query_token_ids): + cls.query_token_ids = query_token_ids + cls.query_length = len(query_token_ids) + + @classmethod + def set_img_feat_length_and_dimension(cls, img_feat): + # +1 because the first region feature is reserved as CLS + cls.img_feat_length = img_feat.shape[0] + 1 + cls.img_feat_dim = img_feat.shape[1] + + @classmethod + def set_bboxes_length_and_dimension(cls, bbox): + # +1 because the first region feature is reserved as CLS + cls.bboxes_length = bbox.shape[0] + 1 + cls.bboxes_dim = bbox.shape[1] + + def __call__(self, data): + img_feats, img_feat_bboxes, img_ids, queries, dataset_indices = zip(*data) + """ + Build batch tensors from a list of (img_feats, img_feat_boxes, img_ids, queries, dataset_indices) tuples. + This data comes from the dataset + Args: + - img_feats: + - img_feat_bboxes: + - img_ids: + - queries: + - dataset_indices: + + Returns: + - img_feature_batch: batch of image features + - img_feat_bboxes_batch: batch of bounding boxes of the image features + - img_feat_length: length of the image features and bounding boxes (all of same size) + - query_token_ids: bert token ids of the tokenized query + - query_length: length of the query + - dataset_indices: indices of the elements of the datasets inside the batch. + """ + + # encode (tokenize) the query + if self.query_token_ids is None and not self.pre_compute_img_embs: + # we don't need to pad or truncate since we only have a single query + # TODO actually we don't even need the tokenizer twice so we could just use a local variable + query_token_ids = torch.LongTensor(self.tokenizer.encode(queries[0])) + self.set_query_token_ids(query_token_ids) + + # prepare image features + if self.img_feat_length is None: + self.set_img_feat_length_and_dimension(img_feats[0]) + + # prepare bounding boxes + if self.bboxes_length is None: + self.set_bboxes_length_and_dimension(img_feat_bboxes[0]) + + assert self.bboxes_length == self.img_feat_length + + # create the image feature batch + batch_size = len(img_feats) + img_feature_batch = torch.zeros(batch_size, self.img_feat_length, self.img_feat_dim) + for i, f in enumerate(img_feats): + # reserve the first token as CLS + img_feature_batch[i, 1:] = f + + # create the image features bounding boxes batch + img_feat_lengths = [self.img_feat_length for _ in range(batch_size)] + img_feat_bboxes_batch = torch.zeros(batch_size, self.bboxes_length, self.bboxes_dim) + for i, box in enumerate(img_feat_bboxes): + img_feat_bboxes_batch[i, 1:] = box + + if self.create_query_batch and not self.pre_compute_img_embs: + # create the full query batch of size B x |Q| + # since the token id is a scalar, the dim is 1 and whe don't need to add it to the batch + # for the BERT embeddings the ids have to be Long + query_token_ids_batch = torch.zeros(batch_size, self.query_length).long() + for i in range(len(queries)): + query_token_ids_batch[i] = self.query_token_ids + query_lengths = [self.query_length for _ in range(batch_size)] + elif not self.create_query_batch and not self.pre_compute_img_embs: + # create a pseudo query batch with only one element of size 1 x |Q| + query_token_ids_batch = self.query_token_ids.unsqueeze(dim=0) + query_lengths = [self.query_length] + else: # self.pre_compute_img_embs == True + # when pre-computing the image embeddings, we don't need (and have) information about the query + query_token_ids_batch = None + query_lengths = None + + return img_feature_batch, img_feat_bboxes_batch, img_feat_lengths, query_token_ids_batch, query_lengths, dataset_indices + + class Collate: def __init__(self, config): self.vocab_type = config['text-model']['name'] @@ -277,12 +521,12 @@ def __call__(self, data): Returns: images: torch tensor of shape (batch_size, 3, 256, 256). - targets: torch tensor of shape (batch_size, padded_length). + targets: torch tensor of shape (batch_size, padded_length). -> the textual tokens lengths: list; valid length for each padded caption. """ # Sort a data list by caption length # data.sort(key=lambda x: len(x[1]), reverse=True) - if len(data[0]) == 5: # TODO: find a better way to distinguish the two + if len(data[0]) == 5: # TODO: find a better way to distinguish the two images, boxes, captions, ids, img_ids = zip(*data) elif len(data[0]) == 4: images, captions, ids, img_ids = zip(*data) @@ -298,12 +542,15 @@ def __call__(self, data): else: if self.vocab_type == 'bert': cap_lengths = [len(self.tokenizer.tokenize(c)) + 2 for c in - captions] # + 2 in order to account for begin and end tokens + captions] # + 2 in order to account for begin and end tokens max_len = max(cap_lengths) - captions_ids = [torch.LongTensor(self.tokenizer.encode(c, max_length=max_len, pad_to_max_length=True)) - for c in captions] + captions_token_ids = [torch.LongTensor(self.tokenizer.encode(c, + max_length=max_len, + padding='max_length', + truncation=True)) + for c in captions] - captions = captions_ids + captions = captions_token_ids # caption_ids are the token ids from bert tokenizer # Merge images (convert tuple of 3D tensor to 4D tensor) preextracted_images = not (images[0].shape[0] == 3) if not preextracted_images: @@ -339,40 +586,46 @@ def __call__(self, data): targets = torch.zeros(len(captions), max(cap_lengths)).long() for i, cap in enumerate(captions): end = cap_lengths[i] - targets[i, :end] = cap[:end] + targets[i, :end] = cap[:end] # caption token ids if not preextracted_images: return images, targets, None, cap_lengths, None, ids else: # features = features.permute(0, 2, 1) + # img_features -> from FRCNN >> B x 2048 + # targets -> padded caption token ids from BERT >> B x max_len(cap_lengths) or(queries) + # feat_lengths -> num of regions in the image (fixed to 36 + 1) >> B x 37 + # cap_lengths -> true length of the non-padded captions or queries >> B x 1 (list of len B) + # out_boxes -> spatial information of the region boxes >> B x 37 x 4 + # ids -> dataset indices wich are in this batch >> 1 x B (tuple of len B) return img_features, targets, feat_lengths, cap_lengths, out_boxes, ids -def get_loader_single(data_name, split, root, json, transform, preextracted_root=None, +def get_loader_single(data_name, split, imgs_root, captions_json, transform, pre_extracted_root=None, batch_size=100, shuffle=True, num_workers=2, ids=None, collate_fn=None, **kwargs): """Returns torch.utils.data.DataLoader for custom coco dataset.""" if 'coco' in data_name: - if preextracted_root is not None: - dataset = BottomUpFeaturesDataset(root=root, - json=json, - features_path=preextracted_root, split=split, + if pre_extracted_root is not None: + dataset = BottomUpFeaturesDataset(imgs_root=imgs_root, + captions_json=captions_json, + features_path=pre_extracted_root, split=split, ids=ids, **kwargs) else: # COCO custom dataset - dataset = CocoDataset(root=root, - json=json, - transform=transform, ids=ids) + dataset = CocoDataset(imgs_root=imgs_root, + captions_json=captions_json, + transform=transform, coco_annotation_ids=ids) elif 'f8k' in data_name or 'f30k' in data_name: - if preextracted_root is not None: - dataset = BottomUpFeaturesDataset(root=root, - json=json, - features_path=preextracted_root, split=split, + if pre_extracted_root is not None: + dataset = BottomUpFeaturesDataset(imgs_root=imgs_root, + captions_json=captions_json, + features_path=pre_extracted_root, split=split, ids=ids, **kwargs) else: - dataset = FlickrDataset(root=root, + dataset = FlickrDataset(root=imgs_root, split=split, - json=json, + json=captions_json, transform=transform) # Data loader @@ -385,7 +638,7 @@ def get_loader_single(data_name, split, root, json, transform, preextracted_root return data_loader -def get_transform(data_name, split_name, config): +def get_transform(data_name=None, split_name=None, config=None): normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) t_list = [] @@ -417,7 +670,7 @@ def get_loaders(config, workers, batch_size=None): roots['train']['img'], roots['train']['cap'], transform, ids=ids['train'], - preextracted_root=preextracted_root, + pre_extracted_root=preextracted_root, batch_size=batch_size, shuffle=True, num_workers=workers, collate_fn=collate_fn, config=config) @@ -427,7 +680,7 @@ def get_loaders(config, workers, batch_size=None): roots['val']['img'], roots['val']['cap'], transform, ids=ids['val'], - preextracted_root=preextracted_root, + pre_extracted_root=preextracted_root, batch_size=batch_size, shuffle=False, num_workers=workers, collate_fn=collate_fn, config=config) @@ -435,6 +688,50 @@ def get_loaders(config, workers, batch_size=None): return train_loader, val_loader +def get_coco_image_retrieval_data(config, query=None, num_workers=32, pre_compute_img_embs=False): + # get the directories that contain the coco json files and coco annotation ids (which we may not need, I think) + roots, coco_annotation_ids = get_paths(config) + + dataset_name = config['image-retrieval']['dataset'] + batch_size = config['image-retrieval']['batch_size'] + split_name = config['image-retrieval']['split'] + + imgs_root = roots[split_name]['img'] + + captions_json = roots[split_name]['cap'] + coco_annotation_ids = coco_annotation_ids[split_name] + num_imgs = config['image-retrieval']['num_imgs'] + pre_extracted_img_features_root = config['image-retrieval']['pre_extracted_img_features_root'] + + use_precomputed_img_embeddings = config['image-retrieval']['use_precomputed_img_embeddings'] + if use_precomputed_img_embeddings: + dataset = PreComputedCocoImageEmbeddingsDataset(captions_json=captions_json, + coco_annotation_ids=coco_annotation_ids, + num_imgs=num_imgs, + config=config, + num_workers=num_workers) + return dataset + + dataset = PreComputedCocoFeaturesDataset(imgs_root=imgs_root, + img_features_path=pre_extracted_img_features_root, + captions_json=captions_json, + coco_annotation_ids=coco_annotation_ids, + query=query, + num_imgs=num_imgs) + + # this creates the batches which get passed to the model (inside the query gets repeated or not based on the config) + collate_fn = InferenceCollate(config, pre_compute_img_embs) + + data_loader = data.DataLoader(dataset=dataset, + batch_size=batch_size, + shuffle=False, + pin_memory=True, + num_workers=num_workers, + collate_fn=collate_fn) + + return data_loader + + def get_test_loader(config, workers, split_name='test', batch_size=None): data_name = config['dataset']['name'] if batch_size is None: @@ -443,15 +740,15 @@ def get_test_loader(config, workers, split_name='test', batch_size=None): # Build Dataset Loader roots, ids = get_paths(config) - preextracted_root = config['image-model']['pre-extracted-features-root'] \ + pre_extracted_root = config['image-model']['pre-extracted-features-root'] \ if 'pre-extracted-features-root' in config['image-model'] else None transform = get_transform(data_name, split_name, config) test_loader = get_loader_single(data_name, split_name, - roots[split_name]['img'], - roots[split_name]['cap'], - transform, ids=ids[split_name], - preextracted_root=preextracted_root, + imgs_root=roots[split_name]['img'], + captions_json=roots[split_name]['cap'], + transform=transform, ids=ids[split_name], + pre_extracted_root=pre_extracted_root, batch_size=batch_size, shuffle=False, num_workers=workers, collate_fn=collate_fn, config=config) diff --git a/environment_min.yml b/environment_min.yml new file mode 100644 index 0000000..504ce86 --- /dev/null +++ b/environment_min.yml @@ -0,0 +1,98 @@ +name: teran +channels: + - pytorch + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _pytorch_select=0.1=cpu_0 + - arrow=0.17.0=py36h9f0ad1d_1 + - binaryornot=0.4.4=py_1 + - blas=1.0=mkl + - brotlipy=0.7.0=py36he6145b8_1001 + - ca-certificates=2020.10.14=0 + - certifi=2020.12.5=py36h06a4308_0 + - cffi=1.14.0=py36h2e261b9_0 + - chardet=3.0.4=py36h9880bd3_1008 + - click=7.1.2=py_0 + - cookiecutter=1.7.2=pyh9f0ad1d_0 + - cryptography=3.2.1=py36h6ec43e4_0 + - cudatoolkit=10.1.243=h6bb024c_0 + - cycler=0.10.0=py_2 + - cython=0.29.21=py36ha357f81_1 + - dataclasses=0.7=pyhe4b4509_6 + - filelock=3.0.12=pyh9f0ad1d_0 + - freetype=2.10.4=h5ab3b9f_0 + - gperftools=2.7=h767d802_2 + - idna=2.10=pyh9f0ad1d_0 + - intel-openmp=2020.2=254 + - jinja2=2.11.2=pyh9f0ad1d_0 + - jinja2-time=0.2.0=py_2 + - joblib=0.17.0=py_0 + - jpeg=9b=h024ee3a_2 + - kiwisolver=1.3.1=py36h51d7077_0 + - lcms2=2.11=h396b838_0 + - libedit=3.1.20191231=h14c3975_1 + - libffi=3.2.1=hf484d3e_1007 + - libgcc-ng=9.1.0=hdf63c60_0 + - libpng=1.6.37=hbc83047_0 + - libstdcxx-ng=9.1.0=hdf63c60_0 + - libtiff=4.1.0=h2733197_1 + - libuv=1.40.0=h7b6447c_0 + - lz4-c=1.9.2=heb0550a_3 + - markupsafe=1.1.1=py36he6145b8_2 + - matplotlib-base=3.3.3=py36he12231b_0 + - mkl=2020.2=256 + - mkl-service=2.3.0=py36he8ac12f_0 + - mkl_fft=1.2.0=py36h23d657b_0 + - mkl_random=1.1.1=py36h0573a6f_0 + - ncurses=6.2=he6710b0_1 + - ninja=1.10.2=py36hff7bd54_0 + - nltk=3.5=py_0 + - numpy=1.19.2=py36h54aff64_0 + - numpy-base=1.19.2=py36hfa32c7d_0 + - olefile=0.46=py36_0 + - openssl=1.1.1h=h7b6447c_0 + - packaging=20.7=pyhd3deb0d_0 + - perl=5.32.0=h36c2ea0_0 + - pillow=8.0.1=py36he98fc37_0 + - pip=20.3.1=py36h06a4308_0 + - poyo=0.5.0=py_0 + - protobuf=3.4.1=py36_0 + - pycocotools=2.0.2=py36h8c4c3a4_1 + - pycparser=2.20=py_2 + - pyopenssl=20.0.0=pyhd8ed1ab_0 + - pyparsing=2.4.7=pyh9f0ad1d_0 + - pysocks=1.7.1=py36h9880bd3_2 + - python=3.6.9=h265db76_0 + - python-dateutil=2.8.1=py_0 + - python-slugify=4.0.1=pyh9f0ad1d_0 + - python_abi=3.6=1_cp36m + - pytorch=1.7.0=py3.6_cuda10.1.243_cudnn7.6.3_0 + - readline=7.0=h7b6447c_5 + - regex=2020.11.13=py36h27cfd23_0 + - requests=2.25.0=pyhd3deb0d_0 + - sacremoses=0.0.43=pyh9f0ad1d_0 + - sentencepiece=0.1.92=py36hdb11119_0 + - setuptools=51.0.0=py36h06a4308_2 + - six=1.15.0=py36h06a4308_0 + - sqlite=3.33.0=h62c20be_0 + - text-unidecode=1.3=py_0 + - tk=8.6.10=hbc83047_0 + - tokenizers=0.9.4=py36h2bc52f9_1 + - torchvision=0.8.1=py36_cu101 + - tornado=6.1=py36h1d69622_0 + - tqdm=4.54.1=pyhd3eb1b0_0 + - transformers=4.0.0=pyhd8ed1ab_0 + - typing_extensions=3.7.4.3=py_0 + - unidecode=1.1.1=py_0 + - urllib3=1.25.11=py_0 + - wheel=0.36.1=pyhd3eb1b0_0 + - whichcraft=0.6.1=py_0 + - xz=5.2.5=h7b6447c_0 + - yaml=0.2.5=h7b6447c_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.5=h9ceee32_0 + - pip: + - pyyaml==5.3.1 +prefix: /home/p0w3r/bin/miniconda3/envs/teran diff --git a/evaluate_utils/compute_relevance.py b/evaluate_utils/compute_relevance.py index aa67bf8..ff2de4b 100644 --- a/evaluate_utils/compute_relevance.py +++ b/evaluate_utils/compute_relevance.py @@ -58,7 +58,7 @@ def get_dataset(config, split): data_name = config['dataset']['name'] if 'coco' in data_name: # COCO custom dataset - dataset = data.CocoDataset(root=roots[split]['img'], json=roots[split]['cap'], ids=ids[split], get_images=False) + dataset = data.CocoDataset(imgs_root=roots[split]['img'], captions_json=roots[split]['cap'], coco_annotation_ids=ids[split], get_images=False) elif 'f8k' in data_name or 'f30k' in data_name: dataset = data.FlickrDataset(root=roots[split]['img'], split=split, json=roots[split]['cap'], get_images=False) return dataset diff --git a/evaluation.py b/evaluation.py index 5cf4569..c815606 100644 --- a/evaluation.py +++ b/evaluation.py @@ -1,78 +1,21 @@ from __future__ import print_function -import numpy - -from data import get_test_loader import time + +import numpy import numpy as np import torch import tqdm -from collections import OrderedDict -from utils import dot_sim, get_model + from evaluate_utils.dcg import DCG from models.loss import order_sim, AlignmentContrastiveLoss - - -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=0): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / (.0001 + self.count) - - def __str__(self): - """String representation for logging - """ - # for values that should be recorded exactly e.g. iteration number - if self.count == 0: - return str(self.val) - # for stats - return '%.4f (%.4f)' % (self.val, self.avg) - - -class LogCollector(object): - """A collection of logging objects that can change from train to val""" - - def __init__(self): - # to keep the order of logged variables deterministic - self.meters = OrderedDict() - - def update(self, k, v, n=0): - # create a new meter if previously not recorded - if k not in self.meters: - self.meters[k] = AverageMeter() - self.meters[k].update(v, n) - - def __str__(self): - """Concatenate the meters in one log line - """ - s = '' - for i, (k, v) in enumerate(self.meters.items()): - if i > 0: - s += ' ' - s += k + ' ' + str(v) - return s - - def tb_log(self, tb_logger, prefix='', step=None): - """Log using tensorboard - """ - for k, v in self.meters.items(): - tb_logger.add_scalar(prefix + k, v.val, global_step=step) +from utils import get_model, AverageMeter, LogCollector +from data import get_coco_image_retrieval_data, get_test_loader def encode_data(model, data_loader, log_step=10, logging=print): - """Encode all images and captions loadable by `data_loader` + """ + Encode all images and captions loadable by `data_loader` """ batch_time = AverageMeter() val_logger = LogCollector() @@ -106,14 +49,13 @@ def encode_data(model, data_loader, log_step=10, logging=print): else: text = targets captions = targets - wembeddings = model.img_txt_enc.txt_enc.word_embeddings(captions.cuda() if torch.cuda.is_available() else captions) # compute the embeddings with torch.no_grad(): _, _, img_emb, cap_emb, cap_length = model.forward_emb(images, text, img_length, cap_length, boxes) # initialize the numpy arrays given the size of the embeddings - if img_embs is None: + if img_embs is None: # N x max_len x 1024 img_embs = torch.zeros((len(data_loader.dataset), max_img_len, img_emb.size(2))) cap_embs = torch.zeros((len(data_loader.dataset), max_cap_len, cap_emb.size(2))) @@ -149,12 +91,14 @@ def encode_data(model, data_loader, log_step=10, logging=print): return img_embs, cap_embs, img_lengths, cap_lengths -def evalrank(config, checkpoint, split='dev', fold5=False): +def evalrank(config, checkpoint, split='dev', fold5=False, eval_t2i=True, eval_i2t=False): """ Evaluate a trained model on either dev or test. If `fold5=True`, 5 fold cross-validation is done (only for MSCOCO). Otherwise, the full data is used for evaluation. """ + evalrank_start_time = time.time() + # load model and options # checkpoint = torch.load(model_path) data_path = config['dataset']['data'] @@ -173,10 +117,15 @@ def evalrank(config, checkpoint, split='dev', fold5=False): ndcg_val_scorer = DCG(config, len(data_loader.dataset), split, rank=25, relevance_methods=['rougeL', 'spice']) # initialize similarity matrix evaluator - sim_matrix_fn = AlignmentContrastiveLoss(aggregation=config['training']['alignment-mode'], return_similarity_mat=True) if config['training']['loss-type'] == 'alignment' else None + sim_matrix_fn = AlignmentContrastiveLoss(aggregation=config['training']['alignment-mode'], + return_similarity_mat=True) if config['training'][ + 'loss-type'] == 'alignment' else None print('Computing results...') + encode_data_start_time = time.time() img_embs, cap_embs, img_lenghts, cap_lenghts = encode_data(model, data_loader) + print(f"Time elapsed for encode_data: {time.time() - encode_data_start_time} seconds.") + torch.cuda.empty_cache() # if checkpoint2 is not None: @@ -195,51 +144,121 @@ def evalrank(config, checkpoint, split='dev', fold5=False): if not fold5: # no cross-validation, full evaluation - r, rt = i2t(img_embs, cap_embs, img_lenghts, cap_lenghts, return_ranks=True, ndcg_scorer=ndcg_val_scorer, sim_function=sim_matrix_fn, cap_batches=5) - ri, rti = t2i(img_embs, cap_embs, img_lenghts, cap_lenghts, return_ranks=True, ndcg_scorer=ndcg_val_scorer, sim_function=sim_matrix_fn, im_batches=5) - ar = (r[0] + r[1] + r[2]) / 3 - ari = (ri[0] + ri[1] + ri[2]) / 3 - rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] - print("rsum: %.1f" % rsum) - print("Average i2t Recall: %.1f" % ar) - print("Image to text: %.1f %.1f %.1f %.1f %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % r) - print("Average t2i Recall: %.1f" % ari) - print("Text to image: %.1f %.1f %.1f %.1f %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % ri) + if eval_i2t: + eval_i2t_start_time = time.time() + + r, rt = i2t(img_embs, + cap_embs, + img_lenghts, + cap_lenghts, + return_ranks=True, + ndcg_scorer=ndcg_val_scorer, + sim_function=sim_matrix_fn, + cap_batches=5) + ar = (r[0] + r[1] + r[2]) / 3 + print("Average i2t Recall: %.1f" % ar) + print("Image to text: %.1f %.1f %.1f %.1f %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % r) + + print(f"Time elapsed for i2t evaluation without 5-fold CV: {time.time() - eval_i2t_start_time} seconds.") + + if eval_t2i: + eval_t2i_start_time = time.time() + + ri, rti = t2i(img_embs, + cap_embs, + img_lenghts, + cap_lenghts, + return_ranks=True, + ndcg_scorer=ndcg_val_scorer, + sim_function=sim_matrix_fn, + im_batches=5) + + ari = (ri[0] + ri[1] + ri[2]) / 3 + print("Average t2i Recall: %.1f" % ari) + print("Text to image: %.1f %.1f %.1f %.1f %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % ri) + + print(f"Time elapsed for t2i evaluation without 5-fold CV: {time.time() - eval_t2i_start_time} seconds.") + + if eval_i2t and eval_t2i: + rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] + print("rsum: %.1f" % rsum) + + + else: # 5fold cross-validation, only for MSCOCO results = [] for i in range(5): - r, rt0 = i2t(img_embs[i * 5000:(i + 1) * 5000], cap_embs[i * 5000:(i + 1) * 5000], - img_lenghts[i * 5000:(i + 1) * 5000], cap_lenghts[i * 5000:(i + 1) * 5000], - return_ranks=True, ndcg_scorer=ndcg_val_scorer, fold_index=i, sim_function=sim_matrix_fn, cap_batches=1) - print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f" % r) - ri, rti0 = t2i(img_embs[i * 5000:(i + 1) * 5000], cap_embs[i * 5000:(i + 1) * 5000], - img_lenghts[i * 5000:(i + 1) * 5000], cap_lenghts[i * 5000:(i + 1) * 5000], - return_ranks=True, ndcg_scorer=ndcg_val_scorer, fold_index=i, sim_function=sim_matrix_fn, im_batches=1) - if i == 0: - rt, rti = rt0, rti0 - print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % ri) - ar = (r[0] + r[1] + r[2]) / 3 - ari = (ri[0] + ri[1] + ri[2]) / 3 - rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] - print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) - results += [list(r) + list(ri) + [ar, ari, rsum]] + if eval_i2t: + r, rt0 = i2t(img_embs[i * 5000:(i + 1) * 5000], cap_embs[i * 5000:(i + 1) * 5000], + img_lenghts[i * 5000:(i + 1) * 5000], cap_lenghts[i * 5000:(i + 1) * 5000], + return_ranks=True, ndcg_scorer=ndcg_val_scorer, fold_index=i, sim_function=sim_matrix_fn, + cap_batches=1) + print("Image to text: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f ndcg_spice=%.4f" % r) + if i == 0: + rt = rt0 + ar = (r[0] + r[1] + r[2]) / 3 + if eval_t2i: + ri, rti0 = t2i(img_embs[i * 5000:(i + 1) * 5000], cap_embs[i * 5000:(i + 1) * 5000], + img_lenghts[i * 5000:(i + 1) * 5000], cap_lenghts[i * 5000:(i + 1) * 5000], + return_ranks=True, ndcg_scorer=ndcg_val_scorer, fold_index=i, sim_function=sim_matrix_fn, + im_batches=1) + print("Text to image: %.1f, %.1f, %.1f, %.1f, %.1f, ndcg_rouge=%.4f, ndcg_spice=%.4f" % ri) + if i == 0: + rti = rti0 + ari = (ri[0] + ri[1] + ri[2]) / 3 + + + if eval_t2i and eval_i2t: + rsum = r[0] + r[1] + r[2] + ri[0] + ri[1] + ri[2] + print("rsum: %.1f ar: %.1f ari: %.1f" % (rsum, ar, ari)) + elif eval_t2i: + print("ari: %.1f" % (ari,)) + elif eval_i2t: + print("ar: %.1f" % (ar,)) + + if eval_t2i and eval_i2t: + results += [list(r) + list(ri) + [ar, ari, rsum]] # 7 + 7 + 3 = 17 elements + elif eval_t2i: + results += [list(ri) + [ari]] # 7 + 1 = 8 elements + elif eval_i2t: + results += [list(r) + [ar]] # 7 + 1 = 8 elements print("-----------------------------------") print("Mean metrics: ") mean_metrics = tuple(np.array(results).mean(axis=0).flatten()) - print("rsum: %.1f" % (mean_metrics[16] * 6)) - print("Average i2t Recall: %.1f" % mean_metrics[14]) - print("Image to text: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % - mean_metrics[:7]) - print("Average t2i Recall: %.1f" % mean_metrics[15]) - print("Text to image: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % - mean_metrics[7:14]) - - torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar') - - -def i2t(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=False, ndcg_scorer=None, fold_index=0, measure='dot', sim_function=None, cap_batches=1): + if eval_t2i and eval_i2t: + print("rsum: %.1f" % (mean_metrics[16] * 6)) + print("Average i2t Recall: %.1f" % mean_metrics[14]) + print("Image to text: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % + mean_metrics[:7]) + print("Average t2i Recall: %.1f" % mean_metrics[15]) + print("Text to image: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % + mean_metrics[7:14]) + elif eval_t2i: + print("Average t2i Recall: %.1f" % mean_metrics[7]) + print("Text to image: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % + mean_metrics[:7]) + elif eval_i2t: + print("Average i2t Recall: %.1f" % mean_metrics[7]) + print("Image to text: %.1f %.1f %.1f %.1f %.1f ndcg_rouge=%.4f ndcg_spice=%.4f" % + mean_metrics[:7]) + + + + + if eval_t2i and eval_i2t: + torch.save({'rt': rt, 'rti': rti}, 'ranks.pth.tar') + elif eval_t2i: + torch.save({'rti': rti}, 'ranks.pth.tar') + elif eval_i2t: + torch.save({'rt': rt}, 'ranks.pth.tar') + + print(f"Time elapsed for evalrank(): {time.time() - evalrank_start_time} seconds.") + + +def i2t(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=False, ndcg_scorer=None, fold_index=0, + measure='dot', sim_function=None, cap_batches=1): """ Images->Text (Image Annotation) Images: (5N, K) matrix of images @@ -281,8 +300,8 @@ def i2t(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=Fals d = d.cpu().numpy().flatten() else: for i in range(cap_batches): - captions_now = captions[i*captions_per_batch:(i+1)*captions_per_batch] - cap_lenghts_now = cap_lenghts[i*captions_per_batch:(i+1)*captions_per_batch] + captions_now = captions[i * captions_per_batch:(i + 1) * captions_per_batch] + cap_lenghts_now = cap_lenghts[i * captions_per_batch:(i + 1) * captions_per_batch] captions_now = captions_now.cuda() d_align = sim_function(im, captions_now, im_len, cap_lenghts_now) @@ -290,7 +309,7 @@ def i2t(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=Fals # d_matching = torch.mm(im[:, 0, :], captions[:, 0, :].t()) # d_matching = d_matching.cpu().numpy().flatten() if d is None: - d = d_align # + d_matching + d = d_align # + d_matching else: d = numpy.concatenate([d, d_align], axis=0) @@ -325,7 +344,8 @@ def i2t(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=Fals return (r1, r5, r10, medr, meanr, mean_rougel_ndcg, mean_spice_ndcg) -def t2i(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=False, ndcg_scorer=None, fold_index=0, measure='dot', sim_function=None, im_batches=1): +def t2i(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=False, ndcg_scorer=None, fold_index=0, + measure='dot', sim_function=None, im_batches=1): """ Text->Images (Image Search) Images: (5N, K) matrix of images @@ -370,25 +390,27 @@ def t2i(images, captions, img_lenghts, cap_lenghts, npts=None, return_ranks=Fals d = d.cpu().numpy() else: for i in range(im_batches): - ims_now = ims[i * images_per_batch:(i+1) * images_per_batch] - ims_len_now = ims_len[i * images_per_batch:(i+1) * images_per_batch] + ims_now = ims[i * images_per_batch:(i + 1) * images_per_batch] + ims_len_now = ims_len[i * images_per_batch:(i + 1) * images_per_batch] ims_now = ims_now.cuda() # d = numpy.dot(queries, ims.T) + # d_align is the (MrSw) aggregated/pooled similarity matrix A in the paper d_align = sim_function(ims_now, queries, ims_len_now, queries_len).t() d_align = d_align.cpu().numpy() # d_matching = torch.mm(queries[:, 0, :], ims[:, 0, :].t()) # d_matching = d_matching.cpu().numpy() if d is None: - d = d_align # + d_matching + d = d_align # + d_matching else: d = numpy.concatenate([d, d_align], axis=1) + # d contains all aggregated/pooled similarity matrices for all query-image pairs in the test set inds = numpy.zeros(d.shape) for i in range(len(inds)): inds[i] = numpy.argsort(d[i])[::-1] - ranks[5 * index + i] = numpy.where(inds[i] == index)[0][ - 0] # in che posizione e' l'immagine (index) che ha questa caption (5*index + i) + # in che posizione e' l'immagine (index) che ha questa caption (5*index + i) + ranks[5 * index + i] = numpy.where(inds[i] == index)[0][0] top50[5 * index + i] = inds[i][0:50] # calculate ndcg if ndcg_scorer is not None: diff --git a/inference.py b/inference.py new file mode 100644 index 0000000..56e026d --- /dev/null +++ b/inference.py @@ -0,0 +1,272 @@ +import argparse +import os +import sys +import time +from pathlib import Path +from typing import List + +import numpy as np +import torch +import tqdm +import yaml + +from data import get_coco_image_retrieval_data, QueryEncoder +from models.loss import AlignmentContrastiveLoss +from models.teran import TERAN +from utils import AverageMeter, LogCollector + + +def persist_img_embs(config, data_loader, dataset_indices, numpy_img_emb): + dst_root = Path(os.getcwd()).joinpath(config['image-retrieval']['pre_computed_img_embeddings_root']) + if not dst_root.exists(): + dst_root.mkdir(parents=True, exist_ok=True) + + assert len(dataset_indices) == len(numpy_img_emb) + img_names = get_image_names(dataset_indices, data_loader) + # TODO do we want to store them in one big npz? + for idx in range(len(img_names)): + dst = dst_root.joinpath(img_names[idx] + '.npz') + if dst.exists(): + continue + np.savez_compressed(str(dst), img_emb=numpy_img_emb[idx]) + + +def encode_data_for_inference(model: TERAN, data_loader, log_step=10, logging=print, pre_compute_img_embs=False): + # compute the embedding vectors v_i, s_j (paper) for each image region and word respectively + # -> forwarding the data through the respective TE stacks + print( + f'{"Pre-" if pre_compute_img_embs else ""}Computing image {"" if pre_compute_img_embs else "and query "}embeddings...') + + # we don't need autograd for inference + model.eval() + + # array to keep all the embeddings + # TODO maybe we can store those embeddings in an index and load it instead of computing each time for each query + query_embs = None + num_query_feats = None + num_img_feats = None # all images have a fixed size of pre-extracted features of 36 + 1 regions + img_embs = None + + # make sure val logger is used + batch_time = AverageMeter() + val_logger = LogCollector() + model.logger = val_logger + + start_time = time.time() + for i, (img_feature_batch, img_feat_bboxes_batch, img_feat_len_batch, query_token_batch, query_len_batch, + dataset_indices) in enumerate(data_loader): + batch_start_time = time.time() + """ + the data loader returns None values for the respective batches if the only query was already loaded + -> query_token_batch, query_len_batch = None, None + """ + + with torch.no_grad(): + # compute the query embedding only in the first iteration (also because there is only 1 query in IR) + if query_embs is None and not pre_compute_img_embs: + # TODO maybe we can get the most matching roi from query_emb_aggr? + query_emb_aggr, query_emb, _ = model.forward_txt(query_token_batch, query_len_batch) + + # store results as np arrays for further processing or persisting + num_query_feats = query_len_batch[0] if isinstance(query_len_batch, list) else query_len_batch + query_feat_dim = query_emb.size(2) + query_embs = torch.zeros((1, num_query_feats, query_feat_dim)) + query_embs[0, :, :] = query_emb.cpu().permute(1, 0, 2) + + # compute every image embedding in the dataset + img_emb_aggr, img_emb = model.forward_img(img_feature_batch, img_feat_len_batch, img_feat_bboxes_batch) + + # init array to store results for further processing or persisting + if img_embs is None: + num_img_feats = img_feat_len_batch[0] if isinstance(img_feat_len_batch, + list) else img_feat_len_batch + img_feat_dim = img_emb.size(2) + img_embs = torch.zeros((len(data_loader.dataset), num_img_feats, img_feat_dim)) + + numpy_img_emb = img_emb.cpu().permute(1, 0, 2) # why are we permuting here? -> TERAN + img_embs[dataset_indices, :, :] = numpy_img_emb + if pre_compute_img_embs: + # if we are in a pre-compute run, persist the arrays + persist_img_embs(model_config, data_loader, dataset_indices, numpy_img_emb) + + # measure elapsed time per batch + batch_time.update(time.time() - batch_start_time) + + if i % log_step == 0: + logging( + f"Batch: [{i}/{len(data_loader)}]\t{str(model.logger)}\tTime {batch_time.val:.3f} ({batch_time.avg:.3f})") + del img_feature_batch, query_token_batch + + print( + f"Time elapsed to {'encode' if not pre_compute_img_embs else 'encode and persist'} data: {time.time() - start_time} seconds.") + return img_embs, query_embs, num_img_feats, num_query_feats + + +def compute_distances(img_embs, query_embs, img_lengths, query_lengths, config): + # initialize similarity matrix evaluator + sim_matrix_fn = AlignmentContrastiveLoss(aggregation=config['image-retrieval']['alignment_mode'], + return_similarity_mat=True) + start_time = time.time() + img_emb_batches = 1 # TODO config / calc + img_embs_per_batch = img_embs.size(0) // img_emb_batches # TODO config variable + + # distances storage + distances = None + + # since its always the same query we can reuse the batch + # (TODO maybe we can even just use a batch of size 1?! -> check the sim_matrix_fn) + query_emb_batch = query_embs[:1] + query_length_batch = [query_lengths[0] if isinstance(query_lengths, list) else query_lengths for _ in range(1)] + query_emb_batch.cuda() + + # batch-wise compute the alignment distance between the images and the query + for i in tqdm.trange(img_emb_batches): + # create the current batch + img_embs_batch = img_embs[i * img_embs_per_batch:(i + 1) * img_embs_per_batch] + img_embs_length_batch = [img_lengths for _ in range(img_embs_per_batch)] + img_embs_batch.cuda() + + # compute and pool the similarity matrices to get the global distance between the image and the query + alignment_distance = sim_matrix_fn(img_embs_batch, query_emb_batch, img_embs_length_batch, query_length_batch) + alignment_distance = alignment_distance.t().cpu().numpy() + + # store the distances + if distances is None: + distances = alignment_distance + else: + distances = np.concatenate([distances, alignment_distance], axis=1) + + # get the img indices descended sorted by the distance matrix + sorted_distance_indices = np.argsort(distances.squeeze())[::-1] + print(f"Time elapsed to compute and pool the similarity matrices: {time.time() - start_time} seconds.") + return sorted_distance_indices + + +def get_image_names(dataset_indices, dataset) -> List[str]: + return [dataset.get_image_metadata(idx)[1]['file_name'] for idx in dataset_indices] + + +def load_precomputed_image_embeddings(config, num_workers): + print("Loading pre-computed image embeddings...") + start = time.time() + # returns a PreComputedCocoImageEmbeddingsDataset + dataset = get_coco_image_retrieval_data(config, num_workers=num_workers) + + # get the img embeddings and convert them to Tensors + np_img_embs = np.array(list(dataset.img_embs.values())) + img_embs = torch.Tensor(np_img_embs) + img_lengths = len(np_img_embs[0]) + print(f"Time elapsed to load pre-computed embeddings and compute query embedding: {time.time() - start} seconds!") + return img_embs, img_lengths, dataset + + +def load_teran(config, checkpoint): + # construct model + model = TERAN(config) + # load model state + model.load_state_dict(checkpoint['model'], strict=False) + return model + + +def top_k_image_retrieval(opts, config, checkpoint) -> List[str]: + model = load_teran(config, checkpoint) + + use_precomputed_img_embeddings = config['image-retrieval']['use_precomputed_img_embeddings'] + if use_precomputed_img_embeddings: + # load pre computed img embs + img_embs, img_lengths, dataset = load_precomputed_image_embeddings(config, num_workers=opts.num_data_workers) + # compute query emb + query_encoder = QueryEncoder(config, model) + query_embs, query_lengths = query_encoder.compute_query_embedding(opts.query) + + else: + # returns a Dataloader of a PreComputedCocoFeaturesDataset + data_loader = get_coco_image_retrieval_data(config, + query=opts.query, + num_workers=opts.num_data_workers) + dataset = data_loader.dataset + # encode the data (i.e. compute the embeddings / TE outputs for the images and query) + img_embs, query_embs, img_lengths, query_lengths = encode_data_for_inference(model, data_loader) + + if opts.device == "cuda": + torch.cuda.empty_cache() + + print(f"Images Embeddings: {img_embs.shape[0]}, Query Embeddings: {query_embs.shape[0]}") + + # compute the matching scores + distance_sorted_indices = compute_distances(img_embs, query_embs, img_lengths, query_lengths, config) + top_k_indices = distance_sorted_indices[:opts.top_k] + + # get the image names + top_k_images = get_image_names(top_k_indices, dataset) + return top_k_images + + +def prepare_model_checkpoint_and_config(opts): + checkpoint = torch.load(opts.model, map_location=torch.device(opts.device)) + print('Checkpoint loaded from {}'.format(opts.model)) + model_checkpoint_config = checkpoint['config'] + + with open(opts.config, 'r') as yml_file: + loaded_config = yaml.load(yml_file) + # Override some mandatory things in the configuration + model_checkpoint_config['dataset']['images-path'] = loaded_config['dataset']['images-path'] + model_checkpoint_config['dataset']['data'] = loaded_config['dataset']['data'] + model_checkpoint_config['image-retrieval'] = loaded_config['image-retrieval'] + + return model_checkpoint_config, checkpoint + + +def pre_compute_img_embeddings(opts, config, checkpoint): + # construct model + model = TERAN(config) + + # load model state + model.load_state_dict(checkpoint['model'], strict=False) + + print('Loading dataset') + data_loader = get_coco_image_retrieval_data(config, + query=opts.query, + num_workers=opts.num_data_workers, + pre_compute_img_embs=True) + + # encode the data (i.e. compute the embeddings / TE outputs for the images and query) + encode_data_for_inference(model, data_loader, pre_compute_img_embs=True) + + +if __name__ == '__main__': + print("CUDA_VISIBLE_DEVICES: " + os.getenv("CUDA_VISIBLE_DEVICES", "NOT SET - ABORTING")) + if os.getenv("CUDA_VISIBLE_DEVICES", None) is None: + sys.exit(1) + + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, + help="Model (checkpoint) to load. E.g. pretrained_models/coco_MrSw.pth.tar", required=True) + parser.add_argument('--pre_compute_img_embeddings', action='store_true', help="If set or true, the image " + "embeddings get precomputed and " + "persisted at the directory " + "specified in the config.") + parser.add_argument('--query', type=str, required='--pre_compute_img_embeddings' not in sys.argv) + parser.add_argument('--num_data_workers', type=int, default=8) + parser.add_argument('--num_images', type=int, default=5000) + parser.add_argument('--top_k', type=int, default=100) + parser.add_argument('--dataset', type=str, choices=['coco'], default='coco') # TODO support other datasets + parser.add_argument('--config', type=str, default='configs/teran_coco_MrSw_IR.yaml', help="Which configuration to " + "use for overriding the" + " checkpoint " + "configuration. See " + "into 'config' folder") + # cpu is only for local test runs + parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') + opts = parser.parse_args() + + model_config, model_checkpoint = prepare_model_checkpoint_and_config(opts) + + if not opts.pre_compute_img_embeddings: + top_k_matches = top_k_image_retrieval(opts, model_config, model_checkpoint) + print(f"##########################################") + print(f"QUERY: {opts.query}") + print(f"######## TOP {opts.top_k} RESULTS ########") + print(top_k_matches) + else: + pre_compute_img_embeddings(opts, model_config, model_checkpoint) diff --git a/models/teran.py b/models/teran.py index 1eeb524..b57be52 100644 --- a/models/teran.py +++ b/models/teran.py @@ -1,23 +1,23 @@ import torch -import torch.nn.init +import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.functional as F -import torch.backends.cudnn as cudnn +import torch.nn.init +from nltk.corpus import stopwords from transformers import BertTokenizer -from models.loss import ContrastiveLoss, PermInvMatchingLoss, AlignmentContrastiveLoss -from models.text import EncoderTextBERT, EncoderText -from models.visual import TransformerPostProcessing, EncoderImage - -from .utils import l2norm, PositionalEncodingImageBoxes, PositionalEncodingText, Aggregator, generate_square_subsequent_mask -from nltk.corpus import stopwords, words as nltk_words +from models.loss import ContrastiveLoss, AlignmentContrastiveLoss +from models.text import EncoderText +from models.visual import EncoderImage +from .utils import l2norm, Aggregator class JointTextImageTransformerEncoder(nn.Module): """ This is a bert caption encoder - transformer image encoder (using bottomup features). - If process the encoder outputs through a transformer, like VilBERT and outputs two different graph embeddings + It process the encoder outputs through a transformer, like VilBERT and outputs two different graph embeddings """ + def __init__(self, config): super().__init__() self.txt_enc = EncoderText(config) @@ -36,8 +36,8 @@ def __init__(self, config): self.shared_transformer = config['model']['shared-transformer'] transformer_layer_1 = nn.TransformerEncoderLayer(d_model=embed_size, nhead=4, - dim_feedforward=2048, - dropout=dropout, activation='relu') + dim_feedforward=2048, + dropout=dropout, activation='relu') self.transformer_encoder_1 = nn.TransformerEncoder(transformer_layer_1, num_layers=layers) if not self.shared_transformer: @@ -51,77 +51,81 @@ def __init__(self, config): self.text_aggregation_type = config['model']['text-aggregation'] self.img_aggregation_type = config['model']['image-aggregation'] - def forward(self, features, captions, feat_len, cap_len, boxes): + def forward_txt(self, captions, cap_len): # process captions by using bert - full_cap_emb_aggr, c_emb = self.txt_enc(captions, cap_len) # B x S x cap_dim - - # process image regions using a two-layer transformer - full_img_emb_aggr, i_emb = self.img_enc(features, feat_len, boxes) # B x S x vis_dim - # i_emb = i_emb.permute(1, 0, 2) # B x S x vis_dim - - bs = features.shape[0] - - # if False: - # # concatenate the embeddings together - # max_summed_lengths = max([x + y for x, y in zip(feat_len, cap_len)]) - # i_c_emb = torch.zeros(bs, max_summed_lengths, self.embed_size) - # i_c_emb = i_c_emb.to(features.device) - # mask = torch.zeros(bs, max_summed_lengths).bool() - # mask = mask.to(features.device) - # for i_c, m, i, c, i_len, c_len in zip(i_c_emb, mask, i_emb, c_emb, feat_len, cap_len): - # i_c[:c_len] = c[:c_len] - # i_c[c_len:c_len + i_len] = i[:i_len] - # m[c_len + i_len:] = True - # - # i_c_emb = i_c_emb.permute(1, 0, 2) # S_vis + S_txt x B x dim - # out = self.transformer_encoder(i_c_emb, src_key_padding_mask=mask) # S_vis + S_txt x B x dim - # - # full_cap_emb = out[0, :, :] - # I = torch.LongTensor(cap_len).view(1, -1, 1) - # I = I.expand(1, bs, self.embed_size).to(features.device) - # full_img_emb = torch.gather(out, dim=0, index=I).squeeze(0) - # else: + full_cap_emb_aggr, c_emb = self.txt_enc(captions, cap_len) # B x S x cap_dim # forward the captions if self.text_aggregation_type is not None: c_emb = self.cap_proj(c_emb) - mask = torch.zeros(bs, max(cap_len)).bool() - mask = mask.to(features.device) + cap_bs = captions.shape[0] + mask = torch.zeros(cap_bs, max(cap_len)).bool() + mask = mask.to(captions.device) for m, c_len in zip(mask, cap_len): m[c_len:] = True - full_cap_emb = self.transformer_encoder_1(c_emb.permute(1, 0, 2), src_key_padding_mask=mask) # S_txt x B x dim + full_cap_emb = self.transformer_encoder_1(c_emb.permute(1, 0, 2), + src_key_padding_mask=mask) # S_txt x B x dim full_cap_emb_aggr = self.text_aggregation(full_cap_emb, cap_len, mask) + + full_cap_emb_aggr = l2norm(full_cap_emb_aggr) + + # normalize even every vector of the set + full_cap_emb = F.normalize(full_cap_emb, p=2, dim=2) # else use the embedding output by the txt model else: full_cap_emb = None + if self.order_embeddings: + full_cap_emb_aggr = torch.abs(full_cap_emb_aggr) + + return full_cap_emb_aggr, full_cap_emb + + def forward_img(self, features, feat_len, boxes): + # process image regions using a two-layer transformer + full_img_emb_aggr, i_emb = self.img_enc(features, feat_len, boxes) # B x S x vis_dim # forward the regions + if self.img_aggregation_type is not None: i_emb = self.img_proj(i_emb) - mask = torch.zeros(bs, max(feat_len)).bool() + feat_bs = features.shape[0] + mask = torch.zeros(feat_bs, max(feat_len)).bool() mask = mask.to(features.device) for m, v_len in zip(mask, feat_len): m[v_len:] = True if self.shared_transformer: - full_img_emb = self.transformer_encoder_1(i_emb.permute(1, 0, 2), src_key_padding_mask=mask) # S_txt x B x dim + full_img_emb = self.transformer_encoder_1(i_emb.permute(1, 0, 2), + src_key_padding_mask=mask) # S_img x B x dim else: - full_img_emb = self.transformer_encoder_2(i_emb.permute(1, 0, 2), src_key_padding_mask=mask) # S_txt x B x dim + full_img_emb = self.transformer_encoder_2(i_emb.permute(1, 0, 2), + src_key_padding_mask=mask) # S_img x B x dim full_img_emb_aggr = self.image_aggregation(full_img_emb, feat_len, mask) + + full_img_emb_aggr = l2norm(full_img_emb_aggr) + # normalize even every vector of the set + full_img_emb = F.normalize(full_img_emb, p=2, dim=2) else: full_img_emb = None - full_cap_emb_aggr = l2norm(full_cap_emb_aggr) - full_img_emb_aggr = l2norm(full_img_emb_aggr) - - # normalize even every vector of the set - full_img_emb = F.normalize(full_img_emb, p=2, dim=2) - full_cap_emb = F.normalize(full_cap_emb, p=2, dim=2) - if self.order_embeddings: - full_cap_emb_aggr = torch.abs(full_cap_emb_aggr) full_img_emb_aggr = torch.abs(full_img_emb_aggr) + + return full_img_emb_aggr, full_img_emb + + def forward(self, features, captions, feat_len, cap_len, boxes): + if captions is not None: + # process captions + full_cap_emb_aggr, full_cap_emb = self.forward_txt(captions, cap_len) + else: + full_cap_emb_aggr, full_cap_emb = None, None + + if features is not None: + # process image regions + full_img_emb_aggr, full_img_emb = self.forward_img(features, feat_len, boxes) + else: + full_img_emb_aggr, full_img_emb = None, None + return full_img_emb_aggr, full_cap_emb_aggr, full_img_emb, full_cap_emb @@ -145,7 +149,8 @@ def __init__(self, config): if 'alignment' in loss_type: self.alignment_criterion = AlignmentContrastiveLoss(margin=config['training']['margin'], measure=config['training']['measure'], - max_violation=config['training']['max-violation'], aggregation=config['training']['alignment-mode']) + max_violation=config['training']['max-violation'], + aggregation=config['training']['alignment-mode']) if 'matching' in loss_type: self.matching_criterion = ContrastiveLoss(margin=config['training']['margin'], measure=config['training']['measure'], @@ -180,32 +185,61 @@ def __init__(self, config): # self.img_enc.eval() # self.txt_enc.eval() + def remove_stopwords(self, captions, cap_feats, cap_len): + # remove stopwords + # keep only word indexes that are not stopwords + good_word_indexes = [[i for i, (tok, w) in enumerate(zip(self.tokenizer.convert_ids_to_tokens(ids), ids)) if + tok not in self.en_stops or w == 0] for ids in captions] # keeps the padding + cap_len = [len(w) - (cap_feats.shape[0] - orig_len) for w, orig_len in zip(good_word_indexes, cap_len)] + min_cut_len = min([len(w) for w in good_word_indexes]) + good_word_indexes = [words[:min_cut_len] for words in good_word_indexes] + good_word_indexes = torch.LongTensor(good_word_indexes).to(cap_feats.device) # B x S + good_word_indexes = good_word_indexes.t().unsqueeze(2).expand(-1, -1, cap_feats.shape[2]) # S x B x dim + cap_feats = cap_feats.gather(dim=0, index=good_word_indexes) + + return cap_feats, cap_len + def forward_emb(self, images, captions, img_len, cap_len, boxes): - """Compute the image and caption embeddings + """ + Compute the image and caption embeddings """ # Set mini-batch dataset if torch.cuda.is_available(): - images = images.cuda() - captions = captions.cuda() - boxes = boxes.cuda() + if images is not None and boxes is not None: + images = images.cuda() + boxes = boxes.cuda() + if captions is not None: + captions = captions.cuda() # Forward img_emb_aggr, cap_emb_aggr, img_feats, cap_feats = self.img_txt_enc(images, captions, img_len, cap_len, boxes) - if self.tokenizer is not None: - # remove stopwords - # keep only word indexes that are not stopwords - good_word_indexes = [[i for i, (tok, w) in enumerate(zip(self.tokenizer.convert_ids_to_tokens(ids), ids)) if - tok not in self.en_stops or w == 0] for ids in captions] # keeps the padding - cap_len = [len(w) - (cap_feats.shape[0] - orig_len) for w, orig_len in zip(good_word_indexes, cap_len)] - min_cut_len = min([len(w) for w in good_word_indexes]) - good_word_indexes = [words[:min_cut_len] for words in good_word_indexes] - good_word_indexes = torch.LongTensor(good_word_indexes).to(cap_feats.device) # B x S - good_word_indexes = good_word_indexes.t().unsqueeze(2).expand(-1, -1, cap_feats.shape[2]) # S x B x dim - cap_feats = cap_feats.gather(dim=0, index=good_word_indexes) + if self.tokenizer is not None and captions is not None: + cap_feats, cap_len = self.remove_stopwords(captions, cap_feats, cap_len) return img_emb_aggr, cap_emb_aggr, img_feats, cap_feats, cap_len + def forward_txt(self, captions, cap_len): + """ + compute txt embeddings only + """ + if torch.cuda.is_available(): + captions = captions.cuda() + cap_emb_aggr, cap_feats = self.img_txt_enc.forward_txt(captions, cap_len) + if self.tokenizer is not None and captions is not None: + cap_feats, cap_len = self.remove_stopwords(captions, cap_feats, cap_len) + return cap_emb_aggr, cap_feats, cap_len + + def forward_img(self, images, img_len, boxes): + """ + compute img embeddings only + """ + if torch.cuda.is_available(): + images = images.cuda() + boxes = boxes.cuda() + img_emb_aggr, img_feats = self.img_txt_enc.forward_img(images, img_len, boxes) + return img_emb_aggr, img_feats + def get_parameters(self): lr_multiplier = 1.0 if self.config['text-model']['fine-tune'] else 0.0 @@ -233,7 +267,7 @@ def forward_loss(self, img_emb, cap_emb, img_emb_set, cap_emb_seq, img_lengths, # bs = img_emb.shape[0] losses = {} - if 'matching' in self.config['training']['loss-type']: + if 'matching' in self.config['training']['loss-type']: matching_loss = self.matching_criterion(img_emb, cap_emb) losses.update({'matching-loss': matching_loss}) self.logger.update('matching_loss', matching_loss.item(), img_emb.size(0)) @@ -262,10 +296,15 @@ def forward(self, images, targets, img_lengths, cap_lengths, boxes=None, ids=Non else: text = targets captions = targets - wembeddings = self.img_txt_enc.txt_enc.word_embeddings(captions.cuda() if torch.cuda.is_available() else captions) + wembeddings = self.img_txt_enc.txt_enc.word_embeddings( + captions.cuda() if torch.cuda.is_available() else captions) # compute the embeddings - img_emb_aggr, cap_emb_aggr, img_feats, cap_feats, cap_lengths = self.forward_emb(images, text, img_lengths, cap_lengths, boxes) + img_emb_aggr, cap_emb_aggr, img_feats, cap_feats, cap_lengths = self.forward_emb(images, + text, + img_lengths, + cap_lengths, + boxes) # NOTE: img_feats and cap_feats are S x B x dim loss_dict = self.forward_loss(img_emb_aggr, cap_emb_aggr, img_feats, cap_feats, img_lengths, cap_lengths) diff --git a/models/text.py b/models/text.py index 0dac895..10d23b0 100644 --- a/models/text.py +++ b/models/text.py @@ -58,7 +58,7 @@ def forward(self, x, lengths): # Reshape *final* output to (batch_size, hidden_size) padded = pad_packed_sequence(out, batch_first=True) I = torch.LongTensor(lengths).view(-1, 1, 1) - I = (I.expand(x.size(0), 1, self.embed_size)-1).to(x.device) + I = (I.expand(x.size(0), 1, self.embed_size) - 1).to(x.device) out = torch.gather(padded[0], 1, I).squeeze(1) # normalization in the joint embedding space @@ -105,6 +105,8 @@ def forward(self, x, lengths): lengths: tensor of lengths (LongTensor) of size B ''' if not self.preextracted or self.post_transformer_layers > 0: + # this code builds the attention_mask so that its 1 for every valid token and pads 0 for the max len + # attention_mask is a kinda padding max_len = max(lengths) attention_mask = torch.ones(x.shape[0], max_len) for e, l in zip(attention_mask, lengths): @@ -115,7 +117,8 @@ def forward(self, x, lengths): outputs = x else: outputs = self.bert_model(x, attention_mask=attention_mask) - outputs = outputs[2][-1] + # https://huggingface.co/transformers/model_doc/bert.html#bertmodel + outputs = outputs[2][-1] # -> hidden_states[-1] if self.post_transformer_layers > 0: outputs = outputs.permute(1, 0, 2) @@ -124,7 +127,7 @@ def forward(self, x, lengths): if self.mean: x = outputs.mean(dim=1) else: - x = outputs[:, 0, :] # from the last layer take only the first word + x = outputs[:, 0, :] # from the last layer take only the first word out = self.map(x) diff --git a/models/utils.py b/models/utils.py index 4f32bd4..3f0bac3 100644 --- a/models/utils.py +++ b/models/utils.py @@ -87,7 +87,8 @@ def forward(self, x, boxes): # x is seq_len x B x dim def l2norm(X): - """L2-normalize columns of X + """ + L2-normalize columns of X """ norm = torch.pow(X, 2).sum(dim=1, keepdim=True).sqrt() X = torch.div(X, norm) diff --git a/test.py b/test.py index 9c38df3..3effe87 100644 --- a/test.py +++ b/test.py @@ -1,13 +1,21 @@ import argparse +import os +import sys -import evaluation -import yaml import torch +import yaml + +import evaluation + def main(opt, current_config): model_checkpoint = opt.checkpoint - checkpoint = torch.load(model_checkpoint) + if opt.gpu: + checkpoint = torch.load(model_checkpoint) # , map_location=torch.device("cpu")) + else: + checkpoint = torch.load(model_checkpoint, map_location=torch.device("cpu")) + print('Checkpoint loaded from {}'.format(model_checkpoint)) loaded_config = checkpoint['config'] @@ -22,15 +30,28 @@ def main(opt, current_config): if current_config is not None: loaded_config['dataset']['images-path'] = current_config['dataset']['images-path'] loaded_config['dataset']['data'] = current_config['dataset']['data'] - loaded_config['image-model']['pre-extracted-features-root'] = current_config['image-model']['pre-extracted-features-root'] + loaded_config['image-model']['pre-extracted-features-root'] = current_config['image-model'][ + 'pre-extracted-features-root'] + loaded_config['training']['bs'] = current_config['training']['bs'] + + evaluation.evalrank(loaded_config, checkpoint, split="test", fold5=False, eval_t2i=opt.t2i, eval_i2t=opt.i2t) - evaluation.evalrank(loaded_config, checkpoint, split="test", fold5=fold5) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('checkpoint', type=str, help="Checkpoint to load") parser.add_argument('--size', type=str, choices=['1k', '5k'], default='1k') - parser.add_argument('--config', type=str, default=None, help="Which configuration to use for overriding the checkpoint configuration. See into 'config' folder") + parser.add_argument('--gpu', type=bool, default=True, help="If false, CPU is used for computations; GPU otherwise.") + parser.add_argument('--t2i', action='store_true', default=True, + help="If set text-to-image (image retrieval) evaluation will be executed.") + parser.add_argument('--i2t', action='store_true', default=False, + help="If set image-to-text (image captioning) evaluation will be executed.") + parser.add_argument('--config', type=str, default=None, help="Which configuration to use for overriding the " + "checkpoint configuration. See into 'config' folder") + + print("CUDA_VISIBLE_DEVICES: " + os.getenv("CUDA_VISIBLE_DEVICES", "NOT SET - ABORTING")) + if os.getenv("CUDA_VISIBLE_DEVICES", None) is None: + sys.exit(1) opt = parser.parse_args() if opt.config is not None: @@ -38,4 +59,4 @@ def main(opt, current_config): config = yaml.load(ymlfile) else: config = None - main(opt, config) \ No newline at end of file + main(opt, config) diff --git a/utils.py b/utils.py index 1f2cb6a..822e782 100644 --- a/utils.py +++ b/utils.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + import numpy from models.teran import TERAN @@ -16,3 +18,61 @@ def cosine_sim(x, y): x = x / numpy.expand_dims(numpy.linalg.norm(x, axis=1), 1) y = y / numpy.expand_dims(numpy.linalg.norm(y, axis=1), 1) return numpy.dot(x, y.T) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=0): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / (.0001 + self.count) + + def __str__(self): + """String representation for logging + """ + # for values that should be recorded exactly e.g. iteration number + if self.count == 0: + return str(self.val) + # for stats + return '%.4f (%.4f)' % (self.val, self.avg) + + +class LogCollector(object): + """A collection of logging objects that can change from train to val""" + + def __init__(self): + # to keep the order of logged variables deterministic + self.meters = OrderedDict() + + def update(self, k, v, n=0): + # create a new meter if previously not recorded + if k not in self.meters: + self.meters[k] = AverageMeter() + self.meters[k].update(v, n) + + def __str__(self): + """Concatenate the meters in one log line + """ + s = '' + for i, (k, v) in enumerate(self.meters.items()): + if i > 0: + s += ' ' + s += k + ' ' + str(v) + return s + + def tb_log(self, tb_logger, prefix='', step=None): + """Log using tensorboard + """ + for k, v in self.meters.items(): + tb_logger.add_scalar(prefix + k, v.val, global_step=step)