diff --git a/README.md b/README.md index 2bf39f1..7e48d8c 100644 --- a/README.md +++ b/README.md @@ -40,11 +40,11 @@ existing approaches on different datasets. ## Setup -- torchvision -- pytorch -- numpy -- tqdm -- tensorboardX +Install dependencies using requirements.txt + +```shell +pip install -r requirements.txt +``` ## Running Models diff --git a/datasets.py b/datasets.py index da588dc..8904c25 100644 --- a/datasets.py +++ b/datasets.py @@ -26,502 +26,505 @@ class BaseDataset(torch.utils.data.Dataset): - """Base class for a dataset.""" + """Base class for a dataset.""" - def __init__(self): - super(BaseDataset, self).__init__() - self.imgs = [] - self.test_queries = [] + def __init__(self): + super(BaseDataset, self).__init__() + self.imgs = [] + self.test_queries = [] - def get_loader(self, - batch_size, - shuffle=False, - drop_last=False, - num_workers=0): - return torch.utils.data.DataLoader( - self, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - drop_last=drop_last, - collate_fn=lambda i: i) + def get_loader(self, + batch_size, + shuffle=False, + drop_last=False, + num_workers=0): + return torch.utils.data.DataLoader( + self, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + drop_last=drop_last, + collate_fn=lambda i: i) - def get_test_queries(self): - return self.test_queries + def get_test_queries(self): + return self.test_queries - def get_all_texts(self): - raise NotImplementedError + def get_all_texts(self): + raise NotImplementedError - def __getitem__(self, idx): - return self.generate_random_query_target() + def __getitem__(self, idx): + return self.generate_random_query_target() - def generate_random_query_target(self): - raise NotImplementedError + def generate_random_query_target(self): + raise NotImplementedError - def get_img(self, idx, raw_img=False): - raise NotImplementedError + def get_img(self, idx, raw_img=False): + raise NotImplementedError class CSSDataset(BaseDataset): - """CSS dataset.""" - - def __init__(self, path, split='train', transform=None): - super(CSSDataset, self).__init__() - - self.img_path = path + '/images/' - self.transform = transform - self.split = split - self.data = np.load(path + '/css_toy_dataset_novel2_small.dup.npy').item() - self.mods = self.data[self.split]['mods'] - self.imgs = [] - for objects in self.data[self.split]['objects_img']: - label = len(self.imgs) - if self.data[self.split].has_key('labels'): - label = self.data[self.split]['labels'][label] - self.imgs += [{ - 'objects': objects, - 'label': label, - 'captions': [str(label)] - }] - - self.imgid2modtarget = {} - for i in range(len(self.imgs)): - self.imgid2modtarget[i] = [] - for i, mod in enumerate(self.mods): - for k in range(len(mod['from'])): - f = mod['from'][k] - t = mod['to'][k] - self.imgid2modtarget[f] += [(i, t)] - - self.generate_test_queries_() - - def generate_test_queries_(self): - test_queries = [] - for mod in self.mods: - for i, j in zip(mod['from'], mod['to']): - test_queries += [{ - 'source_img_id': i, - 'target_caption': self.imgs[j]['captions'][0], - 'mod': { - 'str': mod['to_str'] + """CSS dataset.""" + + def __init__(self, path, split='train', transform=None): + super(CSSDataset, self).__init__() + + self.img_path = path + '/images/' + self.transform = transform + self.split = split + self.data = np.load( + path + '/css_toy_dataset_novel2_small.dup.npy').item() + self.mods = self.data[self.split]['mods'] + self.imgs = [] + for objects in self.data[self.split]['objects_img']: + label = len(self.imgs) + 
if 'labels' in self.data[self.split]: + label = self.data[self.split]['labels'][label] + self.imgs += [{ + 'objects': objects, + 'label': label, + 'captions': [str(label)] + }] + + self.imgid2modtarget = {} + for i in range(len(self.imgs)): + self.imgid2modtarget[i] = [] + for i, mod in enumerate(self.mods): + for k in range(len(mod['from'])): + f = mod['from'][k] + t = mod['to'][k] + self.imgid2modtarget[f] += [(i, t)] + + self.generate_test_queries_() + + def generate_test_queries_(self): + test_queries = [] + for mod in self.mods: + for i, j in zip(mod['from'], mod['to']): + test_queries += [{ + 'source_img_id': i, + 'target_caption': self.imgs[j]['captions'][0], + 'mod': { + 'str': mod['to_str'] + } + }] + self.test_queries = test_queries + + def get_1st_training_query(self): + i = np.random.randint(0, len(self.mods)) + mod = self.mods[i] + j = np.random.randint(0, len(mod['from'])) + self.last_from = mod['from'][j] + self.last_mod = [i] + return mod['from'][j], i, mod['to'][j] + + def get_2nd_training_query(self): + modid, new_to = random.choice(self.imgid2modtarget[self.last_from]) + while modid in self.last_mod: + modid, new_to = random.choice(self.imgid2modtarget[self.last_from]) + self.last_mod += [modid] + # mod = self.mods[modid] + return self.last_from, modid, new_to + + def generate_random_query_target(self): + try: + if len(self.last_mod) < 2: + img1id, modid, img2id = self.get_2nd_training_query() + else: + img1id, modid, img2id = self.get_1st_training_query() + except: + img1id, modid, img2id = self.get_1st_training_query() + + out = {} + out['source_img_id'] = img1id + out['source_img_data'] = self.get_img(img1id) + out['target_img_id'] = img2id + out['target_img_data'] = self.get_img(img2id) + out['mod'] = {'id': modid, 'str': self.mods[modid]['to_str']} + return out + + def __len__(self): + return len(self.imgs) + + def get_all_texts(self): + return [mod['to_str'] for mod in self.mods] + + def get_img(self, idx, raw_img=False, get_2d=False): + """Gets CSS images.""" + def generate_2d_image(objects): + img = np.ones((64, 64, 3)) + colortext2values = { + 'gray': [87, 87, 87], + 'red': [244, 35, 35], + 'blue': [42, 75, 215], + 'green': [29, 205, 20], + 'brown': [129, 74, 25], + 'purple': [129, 38, 192], + 'cyan': [41, 208, 208], + 'yellow': [255, 238, 51] } - }] - self.test_queries = test_queries - - def get_1st_training_query(self): - i = np.random.randint(0, len(self.mods)) - mod = self.mods[i] - j = np.random.randint(0, len(mod['from'])) - self.last_from = mod['from'][j] - self.last_mod = [i] - return mod['from'][j], i, mod['to'][j] - - def get_2nd_training_query(self): - modid, new_to = random.choice(self.imgid2modtarget[self.last_from]) - while modid in self.last_mod: - modid, new_to = random.choice(self.imgid2modtarget[self.last_from]) - self.last_mod += [modid] - # mod = self.mods[modid] - return self.last_from, modid, new_to - - def generate_random_query_target(self): - try: - if len(self.last_mod) < 2: - img1id, modid, img2id = self.get_2nd_training_query() - else: - img1id, modid, img2id = self.get_1st_training_query() - except: - img1id, modid, img2id = self.get_1st_training_query() - - out = {} - out['source_img_id'] = img1id - out['source_img_data'] = self.get_img(img1id) - out['target_img_id'] = img2id - out['target_img_data'] = self.get_img(img2id) - out['mod'] = {'id': modid, 'str': self.mods[modid]['to_str']} - return out - - def __len__(self): - return len(self.imgs) - - def get_all_texts(self): - return [mod['to_str'] for mod in self.mods] - - def 
get_img(self, idx, raw_img=False, get_2d=False): - """Gets CSS images.""" - def generate_2d_image(objects): - img = np.ones((64, 64, 3)) - colortext2values = { - 'gray': [87, 87, 87], - 'red': [244, 35, 35], - 'blue': [42, 75, 215], - 'green': [29, 205, 20], - 'brown': [129, 74, 25], - 'purple': [129, 38, 192], - 'cyan': [41, 208, 208], - 'yellow': [255, 238, 51] - } - for obj in objects: - s = 4.0 - if obj['size'] == 'large': - s *= 2 - c = [0, 0, 0] - for j in range(3): - c[j] = 1.0 * colortext2values[obj['color']][j] / 255.0 - y = obj['pos'][0] * img.shape[0] - x = obj['pos'][1] * img.shape[1] - if obj['shape'] == 'rectangle': - img[int(y - s):int(y + s), int(x - s):int(x + s), :] = c - if obj['shape'] == 'circle': - for y0 in range(int(y - s), int(y + s) + 1): - x0 = x + (abs(y0 - y) - s) - x1 = 2 * x - x0 - img[y0, int(x0):int(x1), :] = c - if obj['shape'] == 'triangle': - for y0 in range(int(y - s), int(y + s)): - x0 = x + (y0 - y + s) / 2 - x1 = 2 * x - x0 - x0, x1 = min(x0, x1), max(x0, x1) - img[y0, int(x0):int(x1), :] = c - return img - - if self.img_path is None or get_2d: - img = generate_2d_image(self.imgs[idx]['objects']) - else: - img_path = self.img_path + ('/css_%s_%06d.png' % (self.split, int(idx))) - with open(img_path, 'rb') as f: - img = PIL.Image.open(f) - img = img.convert('RGB') - - if raw_img: - return img - if self.transform: - img = self.transform(img) - return img + for obj in objects: + s = 4.0 + if obj['size'] == 'large': + s *= 2 + c = [0, 0, 0] + for j in range(3): + c[j] = 1.0 * colortext2values[obj['color']][j] / 255.0 + y = obj['pos'][0] * img.shape[0] + x = obj['pos'][1] * img.shape[1] + if obj['shape'] == 'rectangle': + img[int(y - s):int(y + s), int(x - s):int(x + s), :] = c + if obj['shape'] == 'circle': + for y0 in range(int(y - s), int(y + s) + 1): + x0 = x + (abs(y0 - y) - s) + x1 = 2 * x - x0 + img[y0, int(x0):int(x1), :] = c + if obj['shape'] == 'triangle': + for y0 in range(int(y - s), int(y + s)): + x0 = x + (y0 - y + s) / 2 + x1 = 2 * x - x0 + x0, x1 = min(x0, x1), max(x0, x1) + img[y0, int(x0):int(x1), :] = c + return img + + if self.img_path is None or get_2d: + img = generate_2d_image(self.imgs[idx]['objects']) + else: + img_path = self.img_path + \ + ('/css_%s_%06d.png' % (self.split, int(idx))) + with open(img_path, 'rb') as f: + img = PIL.Image.open(f) + img = img.convert('RGB') + + if raw_img: + return img + if self.transform: + img = self.transform(img) + return img class Fashion200k(BaseDataset): - """Fashion200k dataset.""" - - def __init__(self, path, split='train', transform=None): - super(Fashion200k, self).__init__() - - self.split = split - self.transform = transform - self.img_path = path + '/' - - # get label files for the split - label_path = path + '/labels/' - from os import listdir - from os.path import isfile - from os.path import join - label_files = [ - f for f in listdir(label_path) if isfile(join(label_path, f)) - ] - label_files = [f for f in label_files if split in f] - - # read image info from label files - self.imgs = [] - - def caption_post_process(s): - return s.strip().replace('.', - 'dotmark').replace('?', 'questionmark').replace( - '&', 'andmark').replace('*', 'starmark') - - for filename in label_files: - print('read ' + filename) - with open(label_path + '/' + filename) as f: - lines = f.readlines() - for line in lines: - line = line.split(' ') - img = { - 'file_path': line[0], - 'detection_score': line[1], - 'captions': [caption_post_process(line[2])], - 'split': split, - 'modifiable': False - } - 
self.imgs += [img] - print 'Fashion200k:', len(self.imgs), 'images' - - # generate query for training or testing - if split == 'train': - self.caption_index_init_() - else: - self.generate_test_queries_() - - def get_different_word(self, source_caption, target_caption): - source_words = source_caption.split() - target_words = target_caption.split() - for source_word in source_words: - if source_word not in target_words: - break - for target_word in target_words: - if target_word not in source_words: - break - mod_str = 'replace ' + source_word + ' with ' + target_word - return source_word, target_word, mod_str - - def generate_test_queries_(self): - file2imgid = {} - for i, img in enumerate(self.imgs): - file2imgid[img['file_path']] = i - with open(self.img_path + '/test_queries.txt') as f: - lines = f.readlines() - self.test_queries = [] - for line in lines: - source_file, target_file = line.split() - idx = file2imgid[source_file] - target_idx = file2imgid[target_file] - source_caption = self.imgs[idx]['captions'][0] - target_caption = self.imgs[target_idx]['captions'][0] - source_word, target_word, mod_str = self.get_different_word( - source_caption, target_caption) - self.test_queries += [{ - 'source_img_id': idx, - 'source_caption': source_caption, - 'target_caption': target_caption, - 'mod': { - 'str': mod_str - } - }] - - def caption_index_init_(self): - """ index caption to generate training query-target example on the fly later""" - - # index caption 2 caption_id and caption 2 image_ids - caption2id = {} - id2caption = {} - caption2imgids = {} - for i, img in enumerate(self.imgs): - for c in img['captions']: - if not caption2id.has_key(c): - id2caption[len(caption2id)] = c - caption2id[c] = len(caption2id) - caption2imgids[c] = [] - caption2imgids[c].append(i) - self.caption2imgids = caption2imgids - print len(caption2imgids), 'unique cations' - - # parent captions are 1-word shorter than their children - parent2children_captions = {} - for c in caption2id.keys(): - for w in c.split(): - p = c.replace(w, '') - p = p.replace(' ', ' ').strip() - if not parent2children_captions.has_key(p): - parent2children_captions[p] = [] - if c not in parent2children_captions[p]: - parent2children_captions[p].append(c) - self.parent2children_captions = parent2children_captions - - # identify parent captions for each image - for img in self.imgs: - img['modifiable'] = False - img['parent_captions'] = [] - for p in parent2children_captions: - if len(parent2children_captions[p]) >= 2: - for c in parent2children_captions[p]: - for imgid in caption2imgids[c]: - self.imgs[imgid]['modifiable'] = True - self.imgs[imgid]['parent_captions'] += [p] - num_modifiable_imgs = 0 - for img in self.imgs: - if img['modifiable']: - num_modifiable_imgs += 1 - print 'Modifiable images', num_modifiable_imgs - - def caption_index_sample_(self, idx): - while not self.imgs[idx]['modifiable']: - idx = np.random.randint(0, len(self.imgs)) - - # find random target image (same parent) - img = self.imgs[idx] - while True: - p = random.choice(img['parent_captions']) - c = random.choice(self.parent2children_captions[p]) - if c not in img['captions']: - break - target_idx = random.choice(self.caption2imgids[c]) - - # find the word difference between query and target (not in parent caption) - source_caption = self.imgs[idx]['captions'][0] - target_caption = self.imgs[target_idx]['captions'][0] - source_word, target_word, mod_str = self.get_different_word( - source_caption, target_caption) - return idx, target_idx, source_word, 
target_word, mod_str - - def get_all_texts(self): - texts = [] - for img in self.imgs: - for c in img['captions']: - texts.append(c) - return texts - - def __len__(self): - return len(self.imgs) - - def __getitem__(self, idx): - idx, target_idx, source_word, target_word, mod_str = self.caption_index_sample_( - idx) - out = {} - out['source_img_id'] = idx - out['source_img_data'] = self.get_img(idx) - out['source_caption'] = self.imgs[idx]['captions'][0] - out['target_img_id'] = target_idx - out['target_img_data'] = self.get_img(target_idx) - out['target_caption'] = self.imgs[target_idx]['captions'][0] - out['mod'] = {'str': mod_str} - return out - - def get_img(self, idx, raw_img=False): - img_path = self.img_path + self.imgs[idx]['file_path'] - with open(img_path, 'rb') as f: - img = PIL.Image.open(f) - img = img.convert('RGB') - if raw_img: - return img - if self.transform: - img = self.transform(img) - return img + """Fashion200k dataset.""" + + def __init__(self, path, split='train', transform=None): + super(Fashion200k, self).__init__() + + self.split = split + self.transform = transform + self.img_path = path + '/' + + # get label files for the split + label_path = path + '/labels/' + from os import listdir + from os.path import isfile + from os.path import join + label_files = [ + f for f in listdir(label_path) if isfile(join(label_path, f)) + ] + label_files = [f for f in label_files if split in f] + + # read image info from label files + self.imgs = [] + + def caption_post_process(s): + return s.strip().replace('.', + 'dotmark').replace('?', 'questionmark').replace( + '&', 'andmark').replace('*', 'starmark') + + for filename in label_files: + print('read ' + filename) + with open(label_path + '/' + filename) as f: + lines = f.readlines() + for line in lines: + line = line.split(' ') + img = { + 'file_path': line[0], + 'detection_score': line[1], + 'captions': [caption_post_process(line[2])], + 'split': split, + 'modifiable': False + } + self.imgs += [img] + print('Fashion200k:', len(self.imgs), 'images') + + # generate query for training or testing + if split == 'train': + self.caption_index_init_() + else: + self.generate_test_queries_() + + def get_different_word(self, source_caption, target_caption): + source_words = source_caption.split() + target_words = target_caption.split() + for source_word in source_words: + if source_word not in target_words: + break + for target_word in target_words: + if target_word not in source_words: + break + mod_str = 'replace ' + source_word + ' with ' + target_word + return source_word, target_word, mod_str + + def generate_test_queries_(self): + file2imgid = {} + for i, img in enumerate(self.imgs): + file2imgid[img['file_path']] = i + with open(self.img_path + '/test_queries.txt') as f: + lines = f.readlines() + self.test_queries = [] + for line in lines: + source_file, target_file = line.split() + idx = file2imgid[source_file] + target_idx = file2imgid[target_file] + source_caption = self.imgs[idx]['captions'][0] + target_caption = self.imgs[target_idx]['captions'][0] + source_word, target_word, mod_str = self.get_different_word( + source_caption, target_caption) + self.test_queries += [{ + 'source_img_id': idx, + 'source_caption': source_caption, + 'target_caption': target_caption, + 'mod': { + 'str': mod_str + } + }] + + def caption_index_init_(self): + """ index caption to generate training query-target example on the fly later""" + + # index caption 2 caption_id and caption 2 image_ids + caption2id = {} + id2caption = {} + caption2imgids 
= {} + for i, img in enumerate(self.imgs): + for c in img['captions']: + if not c in caption2id: + id2caption[len(caption2id)] = c + caption2id[c] = len(caption2id) + caption2imgids[c] = [] + caption2imgids[c].append(i) + self.caption2imgids = caption2imgids + print(len(caption2imgids), 'unique cations') + + # parent captions are 1-word shorter than their children + parent2children_captions = {} + for c in caption2id.keys(): + for w in c.split(): + p = c.replace(w, '') + p = p.replace(' ', ' ').strip() + if not (p in parent2children_captions): + parent2children_captions[p] = [] + if c not in parent2children_captions[p]: + parent2children_captions[p].append(c) + self.parent2children_captions = parent2children_captions + + # identify parent captions for each image + for img in self.imgs: + img['modifiable'] = False + img['parent_captions'] = [] + for p in parent2children_captions: + if len(parent2children_captions[p]) >= 2: + for c in parent2children_captions[p]: + for imgid in caption2imgids[c]: + self.imgs[imgid]['modifiable'] = True + self.imgs[imgid]['parent_captions'] += [p] + num_modifiable_imgs = 0 + for img in self.imgs: + if img['modifiable']: + num_modifiable_imgs += 1 + print('Modifiable images', num_modifiable_imgs) + + def caption_index_sample_(self, idx): + while not self.imgs[idx]['modifiable']: + idx = np.random.randint(0, len(self.imgs)) + + # find random target image (same parent) + img = self.imgs[idx] + while True: + p = random.choice(img['parent_captions']) + c = random.choice(self.parent2children_captions[p]) + if c not in img['captions']: + break + target_idx = random.choice(self.caption2imgids[c]) + + # find the word difference between query and target (not in parent + # caption) + source_caption = self.imgs[idx]['captions'][0] + target_caption = self.imgs[target_idx]['captions'][0] + source_word, target_word, mod_str = self.get_different_word( + source_caption, target_caption) + return idx, target_idx, source_word, target_word, mod_str + + def get_all_texts(self): + texts = [] + for img in self.imgs: + for c in img['captions']: + texts.append(c) + return texts + + def __len__(self): + return len(self.imgs) + + def __getitem__(self, idx): + idx, target_idx, source_word, target_word, mod_str = self.caption_index_sample_( + idx) + out = {} + out['source_img_id'] = idx + out['source_img_data'] = self.get_img(idx) + out['source_caption'] = self.imgs[idx]['captions'][0] + out['target_img_id'] = target_idx + out['target_img_data'] = self.get_img(target_idx) + out['target_caption'] = self.imgs[target_idx]['captions'][0] + out['mod'] = {'str': mod_str} + return out + + def get_img(self, idx, raw_img=False): + img_path = self.img_path + self.imgs[idx]['file_path'] + with open(img_path, 'rb') as f: + img = PIL.Image.open(f) + img = img.convert('RGB') + if raw_img: + return img + if self.transform: + img = self.transform(img) + return img class MITStates(BaseDataset): - """MITStates dataset.""" - - def __init__(self, path, split='train', transform=None): - super(MITStates, self).__init__() - self.path = path - self.transform = transform - self.split = split - - self.imgs = [] - test_nouns = [ - u'armor', u'bracelet', u'bush', u'camera', u'candy', u'castle', - u'ceramic', u'cheese', u'clock', u'clothes', u'coffee', u'fan', u'fig', - u'fish', u'foam', u'forest', u'fruit', u'furniture', u'garden', u'gate', - u'glass', u'horse', u'island', u'laptop', u'lead', u'lightning', - u'mirror', u'orange', u'paint', u'persimmon', u'plastic', u'plate', - u'potato', u'road', u'rubber', u'sand', 
u'shell', u'sky', u'smoke', - u'steel', u'stream', u'table', u'tea', u'tomato', u'vacuum', u'wax', - u'wheel', u'window', u'wool' - ] - - from os import listdir - for f in listdir(path + '/images'): - if ' ' not in f: - continue - adj, noun = f.split() - if adj == 'adj': - continue - if split == 'train' and noun in test_nouns: - continue - if split == 'test' and noun not in test_nouns: - continue - - for file_path in listdir(path + '/images/' + f): - assert (file_path.endswith('jpg')) - self.imgs += [{ - 'file_path': path + '/images/' + f + '/' + file_path, - 'captions': [f], - 'adj': adj, - 'noun': noun - }] - - self.caption_index_init_() - if split == 'test': - self.generate_test_queries_() - - def get_all_texts(self): - texts = [] - for img in self.imgs: - texts += img['captions'] - return texts - - def __getitem__(self, idx): - try: - self.saved_item - except: - self.saved_item = None - if self.saved_item is None: - while True: - idx, target_idx1 = self.caption_index_sample_(idx) - idx, target_idx2 = self.caption_index_sample_(idx) - if self.imgs[target_idx1]['adj'] != self.imgs[target_idx2]['adj']: - break - idx, target_idx = [idx, target_idx1] - self.saved_item = [idx, target_idx2] - else: - idx, target_idx = self.saved_item - self.saved_item = None - - mod_str = self.imgs[target_idx]['adj'] - - return { - 'source_img_id': idx, - 'source_img_data': self.get_img(idx), - 'source_caption': self.imgs[idx]['captions'][0], - 'target_img_id': target_idx, - 'target_img_data': self.get_img(target_idx), - 'target_caption': self.imgs[target_idx]['captions'][0], - 'mod': { - 'str': mod_str + """MITStates dataset.""" + + def __init__(self, path, split='train', transform=None): + super(MITStates, self).__init__() + self.path = path + self.transform = transform + self.split = split + + self.imgs = [] + test_nouns = [ + u'armor', u'bracelet', u'bush', u'camera', u'candy', u'castle', + u'ceramic', u'cheese', u'clock', u'clothes', u'coffee', u'fan', u'fig', + u'fish', u'foam', u'forest', u'fruit', u'furniture', u'garden', u'gate', + u'glass', u'horse', u'island', u'laptop', u'lead', u'lightning', + u'mirror', u'orange', u'paint', u'persimmon', u'plastic', u'plate', + u'potato', u'road', u'rubber', u'sand', u'shell', u'sky', u'smoke', + u'steel', u'stream', u'table', u'tea', u'tomato', u'vacuum', u'wax', + u'wheel', u'window', u'wool' + ] + + from os import listdir + for f in listdir(path + '/images'): + if ' ' not in f: + continue + adj, noun = f.split() + if adj == 'adj': + continue + if split == 'train' and noun in test_nouns: + continue + if split == 'test' and noun not in test_nouns: + continue + + for file_path in listdir(path + '/images/' + f): + assert (file_path.endswith('jpg')) + self.imgs += [{ + 'file_path': path + '/images/' + f + '/' + file_path, + 'captions': [f], + 'adj': adj, + 'noun': noun + }] + + self.caption_index_init_() + if split == 'test': + self.generate_test_queries_() + + def get_all_texts(self): + texts = [] + for img in self.imgs: + texts += img['captions'] + return texts + + def __getitem__(self, idx): + try: + self.saved_item + except: + self.saved_item = None + if self.saved_item is None: + while True: + idx, target_idx1 = self.caption_index_sample_(idx) + idx, target_idx2 = self.caption_index_sample_(idx) + if self.imgs[target_idx1]['adj'] != self.imgs[target_idx2]['adj']: + break + idx, target_idx = [idx, target_idx1] + self.saved_item = [idx, target_idx2] + else: + idx, target_idx = self.saved_item + self.saved_item = None + + mod_str = self.imgs[target_idx]['adj'] + 
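+        # Descriptive note on the sampling logic above: caption_index_sample_
+        # draws a target image that shares this source image's noun, and the
+        # target's adjective alone becomes the modification text (mod_str).
+        # Two such targets with *different* adjectives are sampled per source;
+        # one is returned now, the other is cached in self.saved_item so the
+        # next __getitem__ call reuses the same source with the second target.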
+ return { + 'source_img_id': idx, + 'source_img_data': self.get_img(idx), + 'source_caption': self.imgs[idx]['captions'][0], + 'target_img_id': target_idx, + 'target_img_data': self.get_img(target_idx), + 'target_caption': self.imgs[target_idx]['captions'][0], + 'mod': { + 'str': mod_str + } } - } - - def caption_index_init_(self): - self.caption2imgids = {} - self.noun2adjs = {} - for i, img in enumerate(self.imgs): - cap = img['captions'][0] - adj = img['adj'] - noun = img['noun'] - if cap not in self.caption2imgids.keys(): - self.caption2imgids[cap] = [] - if noun not in self.noun2adjs.keys(): - self.noun2adjs[noun] = [] - self.caption2imgids[cap].append(i) - if adj not in self.noun2adjs[noun]: - self.noun2adjs[noun].append(adj) - for noun, adjs in self.noun2adjs.iteritems(): - assert len(adjs) >= 2 - - def caption_index_sample_(self, idx): - noun = self.imgs[idx]['noun'] - # adj = self.imgs[idx]['adj'] - target_adj = random.choice(self.noun2adjs[noun]) - target_caption = target_adj + ' ' + noun - target_idx = random.choice(self.caption2imgids[target_caption]) - return idx, target_idx - - def generate_test_queries_(self): - self.test_queries = [] - for idx, img in enumerate(self.imgs): - adj = img['adj'] - noun = img['noun'] - for target_adj in self.noun2adjs[noun]: - if target_adj != adj: - mod_str = target_adj - self.test_queries += [{ - 'source_img_id': idx, - 'source_caption': adj + ' ' + noun, - 'target_caption': target_adj + ' ' + noun, - 'mod': { - 'str': mod_str - } - }] - print len(self.test_queries), 'test queries' - - def __len__(self): - return len(self.imgs) - - def get_img(self, idx, raw_img=False): - img_path = self.imgs[idx]['file_path'] - with open(img_path, 'rb') as f: - img = PIL.Image.open(f) - img = img.convert('RGB') - if raw_img: - return img - if self.transform: - img = self.transform(img) - return img + + def caption_index_init_(self): + self.caption2imgids = {} + self.noun2adjs = {} + for i, img in enumerate(self.imgs): + cap = img['captions'][0] + adj = img['adj'] + noun = img['noun'] + if cap not in self.caption2imgids.keys(): + self.caption2imgids[cap] = [] + if noun not in self.noun2adjs.keys(): + self.noun2adjs[noun] = [] + self.caption2imgids[cap].append(i) + if adj not in self.noun2adjs[noun]: + self.noun2adjs[noun].append(adj) + for noun, adjs in self.noun2adjs.iteritems(): + assert len(adjs) >= 2 + + def caption_index_sample_(self, idx): + noun = self.imgs[idx]['noun'] + # adj = self.imgs[idx]['adj'] + target_adj = random.choice(self.noun2adjs[noun]) + target_caption = target_adj + ' ' + noun + target_idx = random.choice(self.caption2imgids[target_caption]) + return idx, target_idx + + def generate_test_queries_(self): + self.test_queries = [] + for idx, img in enumerate(self.imgs): + adj = img['adj'] + noun = img['noun'] + for target_adj in self.noun2adjs[noun]: + if target_adj != adj: + mod_str = target_adj + self.test_queries += [{ + 'source_img_id': idx, + 'source_caption': adj + ' ' + noun, + 'target_caption': target_adj + ' ' + noun, + 'mod': { + 'str': mod_str + } + }] + print(len(self.test_queries), 'test queries') + + def __len__(self): + return len(self.imgs) + + def get_img(self, idx, raw_img=False): + img_path = self.imgs[idx]['file_path'] + with open(img_path, 'rb') as f: + img = PIL.Image.open(f) + img = img.convert('RGB') + if raw_img: + return img + if self.transform: + img = self.transform(img) + return img diff --git a/img_text_composition_models.py b/img_text_composition_models.py index debbf7c..54a8b5d 100644 --- 
a/img_text_composition_models.py +++ b/img_text_composition_models.py @@ -25,231 +25,233 @@ class ConCatModule(torch.nn.Module): - def __init__(self): - super(ConCatModule, self).__init__() + def __init__(self): + super(ConCatModule, self).__init__() - def forward(self, x): - x = torch.cat(x, dim=1) - return x + def forward(self, x): + x = torch.cat(x, dim=1) + return x class ImgTextCompositionBase(torch.nn.Module): - """Base class for image + text composition.""" - - def __init__(self): - super(ImgTextCompositionBase, self).__init__() - self.normalization_layer = torch_functions.NormalizationLayer( - normalize_scale=4.0, learn_scale=True) - self.soft_triplet_loss = torch_functions.TripletLoss() - - def extract_img_feature(self, imgs): - raise NotImplementedError - - def extract_text_feature(self, texts): - raise NotImplementedError - - def compose_img_text(self, imgs, texts): - raise NotImplementedError - - def compute_loss(self, - imgs_query, - modification_texts, - imgs_target, - soft_triplet_loss=True): - mod_img1 = self.compose_img_text(imgs_query, modification_texts) - mod_img1 = self.normalization_layer(mod_img1) - img2 = self.extract_img_feature(imgs_target) - img2 = self.normalization_layer(img2) - assert (mod_img1.shape[0] == img2.shape[0] and - mod_img1.shape[1] == img2.shape[1]) - if soft_triplet_loss: - return self.compute_soft_triplet_loss_(mod_img1, img2) - else: - return self.compute_batch_based_classification_loss_(mod_img1, img2) - - def compute_soft_triplet_loss_(self, mod_img1, img2): - triplets = [] - labels = range(mod_img1.shape[0]) + range(img2.shape[0]) - for i in range(len(labels)): - triplets_i = [] - for j in range(len(labels)): - if labels[i] == labels[j] and i != j: - for k in range(len(labels)): - if labels[i] != labels[k]: - triplets_i.append([i, j, k]) - np.random.shuffle(triplets_i) - triplets += triplets_i[:3] - assert (triplets and len(triplets) < 2000) - return self.soft_triplet_loss(torch.cat([mod_img1, img2]), triplets) - - def compute_batch_based_classification_loss_(self, mod_img1, img2): - x = torch.mm(mod_img1, img2.transpose(0, 1)) - labels = torch.tensor(range(x.shape[0])).long() - labels = torch.autograd.Variable(labels).cuda() - return F.cross_entropy(x, labels) + """Base class for image + text composition.""" + + def __init__(self): + super(ImgTextCompositionBase, self).__init__() + self.normalization_layer = torch_functions.NormalizationLayer( + normalize_scale=4.0, learn_scale=True) + self.soft_triplet_loss = torch_functions.TripletLoss() + + def extract_img_feature(self, imgs): + raise NotImplementedError + + def extract_text_feature(self, texts): + raise NotImplementedError + + def compose_img_text(self, imgs, texts): + raise NotImplementedError + + def compute_loss(self, + imgs_query, + modification_texts, + imgs_target, + soft_triplet_loss=True): + mod_img1 = self.compose_img_text(imgs_query, modification_texts) + mod_img1 = self.normalization_layer(mod_img1) + img2 = self.extract_img_feature(imgs_target) + img2 = self.normalization_layer(img2) + assert (mod_img1.shape[0] == img2.shape[0] and + mod_img1.shape[1] == img2.shape[1]) + if soft_triplet_loss: + return self.compute_soft_triplet_loss_(mod_img1, img2) + else: + return self.compute_batch_based_classification_loss_(mod_img1, img2) + + def compute_soft_triplet_loss_(self, mod_img1, img2): + triplets = [] + labels = range(mod_img1.shape[0]) + range(img2.shape[0]) + for i in range(len(labels)): + triplets_i = [] + for j in range(len(labels)): + if labels[i] == labels[j] and i != j: + 
for k in range(len(labels)): + if labels[i] != labels[k]: + triplets_i.append([i, j, k]) + np.random.shuffle(triplets_i) + triplets += triplets_i[:3] + assert (triplets and len(triplets) < 2000) + return self.soft_triplet_loss(torch.cat([mod_img1, img2]), triplets) + + def compute_batch_based_classification_loss_(self, mod_img1, img2): + x = torch.mm(mod_img1, img2.transpose(0, 1)) + labels = torch.tensor(range(x.shape[0])).long() + labels = torch.autograd.Variable(labels).cuda() + return F.cross_entropy(x, labels) class ImgEncoderTextEncoderBase(ImgTextCompositionBase): - """Base class for image and text encoder.""" + """Base class for image and text encoder.""" - def __init__(self, texts, embed_dim): - super(ImgEncoderTextEncoderBase, self).__init__() + def __init__(self, texts, embed_dim): + super(ImgEncoderTextEncoderBase, self).__init__() - # img model - img_model = torchvision.models.resnet18(pretrained=True) + # img model + img_model = torchvision.models.resnet18(pretrained=True) - class GlobalAvgPool2d(torch.nn.Module): + class GlobalAvgPool2d(torch.nn.Module): - def forward(self, x): - return F.adaptive_avg_pool2d(x, (1, 1)) + def forward(self, x): + return F.adaptive_avg_pool2d(x, (1, 1)) - img_model.avgpool = GlobalAvgPool2d() - img_model.fc = torch.nn.Sequential(torch.nn.Linear(512, embed_dim)) - self.img_model = img_model + img_model.avgpool = GlobalAvgPool2d() + img_model.fc = torch.nn.Sequential(torch.nn.Linear(512, embed_dim)) + self.img_model = img_model - # text model - self.text_model = text_model.TextLSTMModel( - texts_to_build_vocab=texts, - word_embed_dim=embed_dim, - lstm_hidden_dim=embed_dim) + # text model + self.text_model = text_model.TextLSTMModel( + texts_to_build_vocab=texts, + word_embed_dim=embed_dim, + lstm_hidden_dim=embed_dim) - def extract_img_feature(self, imgs): - return self.img_model(imgs) + def extract_img_feature(self, imgs): + return self.img_model(imgs) - def extract_text_feature(self, texts): - return self.text_model(texts) + def extract_text_feature(self, texts): + return self.text_model(texts) class SimpleModelImageOnly(ImgEncoderTextEncoderBase): - def compose_img_text(self, imgs, texts): - return self.extract_img_feature(imgs) + def compose_img_text(self, imgs, texts): + return self.extract_img_feature(imgs) class SimpleModelTextOnly(ImgEncoderTextEncoderBase): - def compose_img_text(self, imgs, texts): - return self.extract_text_feature(texts) + def compose_img_text(self, imgs, texts): + return self.extract_text_feature(texts) class Concat(ImgEncoderTextEncoderBase): - """Concatenation model.""" - - def __init__(self, texts, embed_dim): - super(Concat, self).__init__(texts, embed_dim) - - # composer - class Composer(torch.nn.Module): - """Inner composer class.""" - - def __init__(self): - super(Composer, self).__init__() - self.m = torch.nn.Sequential( - torch.nn.BatchNorm1d(2 * embed_dim), torch.nn.ReLU(), - torch.nn.Linear(2 * embed_dim, 2 * embed_dim), - torch.nn.BatchNorm1d(2 * embed_dim), torch.nn.ReLU(), - torch.nn.Dropout(0.1), torch.nn.Linear(2 * embed_dim, embed_dim)) - - def forward(self, x): - f = torch.cat(x, dim=1) - f = self.m(f) - return f + """Concatenation model.""" + + def __init__(self, texts, embed_dim): + super(Concat, self).__init__(texts, embed_dim) + + # composer + class Composer(torch.nn.Module): + """Inner composer class.""" - self.composer = Composer() + def __init__(self): + super(Composer, self).__init__() + self.m = torch.nn.Sequential( + torch.nn.BatchNorm1d(2 * embed_dim), torch.nn.ReLU(), + torch.nn.Linear(2 
* embed_dim, 2 * embed_dim), + torch.nn.BatchNorm1d(2 * embed_dim), torch.nn.ReLU(), + torch.nn.Dropout(0.1), torch.nn.Linear(2 * embed_dim, embed_dim)) - def compose_img_text(self, imgs, texts): - img_features = self.extract_img_feature(imgs) - text_features = self.extract_text_feature(texts) - return self.compose_img_text_features(img_features, text_features) + def forward(self, x): + f = torch.cat(x, dim=1) + f = self.m(f) + return f - def compose_img_text_features(self, img_features, text_features): - return self.composer((img_features, text_features)) + self.composer = Composer() + + def compose_img_text(self, imgs, texts): + img_features = self.extract_img_feature(imgs) + text_features = self.extract_text_feature(texts) + return self.compose_img_text_features(img_features, text_features) + + def compose_img_text_features(self, img_features, text_features): + return self.composer((img_features, text_features)) class TIRG(ImgEncoderTextEncoderBase): - """The TIGR model. - - The method is described in - Nam Vo, Lu Jiang, Chen Sun, Kevin Murphy, Li-Jia Li, Li Fei-Fei, James Hays. - "Composing Text and Image for Image Retrieval - An Empirical Odyssey" - CVPR 2019. arXiv:1812.07119 - """ - - def __init__(self, texts, embed_dim): - super(TIRG, self).__init__(texts, embed_dim) - - self.a = torch.nn.Parameter(torch.tensor([1.0, 10.0, 1.0, 1.0])) - self.gated_feature_composer = torch.nn.Sequential( - ConCatModule(), torch.nn.BatchNorm1d(2 * embed_dim), torch.nn.ReLU(), - torch.nn.Linear(2 * embed_dim, embed_dim)) - self.res_info_composer = torch.nn.Sequential( - ConCatModule(), torch.nn.BatchNorm1d(2 * embed_dim), torch.nn.ReLU(), - torch.nn.Linear(2 * embed_dim, 2 * embed_dim), torch.nn.ReLU(), - torch.nn.Linear(2 * embed_dim, embed_dim)) - - def compose_img_text(self, imgs, texts): - img_features = self.extract_img_feature(imgs) - text_features = self.extract_text_feature(texts) - return self.compose_img_text_features(img_features, text_features) - - def compose_img_text_features(self, img_features, text_features): - f1 = self.gated_feature_composer((img_features, text_features)) - f2 = self.res_info_composer((img_features, text_features)) - f = F.sigmoid(f1) * img_features * self.a[0] + f2 * self.a[1] - return f + """The TIGR model. + + The method is described in + Nam Vo, Lu Jiang, Chen Sun, Kevin Murphy, Li-Jia Li, Li Fei-Fei, James Hays. + "Composing Text and Image for Image Retrieval - An Empirical Odyssey" + CVPR 2019. 
arXiv:1812.07119 + """ + + def __init__(self, texts, embed_dim): + super(TIRG, self).__init__(texts, embed_dim) + + self.a = torch.nn.Parameter(torch.tensor([1.0, 10.0, 1.0, 1.0])) + self.gated_feature_composer = torch.nn.Sequential( + ConCatModule(), torch.nn.BatchNorm1d(2 * embed_dim), torch.nn.ReLU(), + torch.nn.Linear(2 * embed_dim, embed_dim)) + self.res_info_composer = torch.nn.Sequential( + ConCatModule(), torch.nn.BatchNorm1d(2 * embed_dim), torch.nn.ReLU(), + torch.nn.Linear(2 * embed_dim, 2 * embed_dim), torch.nn.ReLU(), + torch.nn.Linear(2 * embed_dim, embed_dim)) + + def compose_img_text(self, imgs, texts): + img_features = self.extract_img_feature(imgs) + text_features = self.extract_text_feature(texts) + return self.compose_img_text_features(img_features, text_features) + + def compose_img_text_features(self, img_features, text_features): + f1 = self.gated_feature_composer((img_features, text_features)) + f2 = self.res_info_composer((img_features, text_features)) + f = F.sigmoid(f1) * img_features * self.a[0] + f2 * self.a[1] + return f class TIRGLastConv(ImgEncoderTextEncoderBase): - """The TIGR model with spatial modification over the last conv layer. - - The method is described in - Nam Vo, Lu Jiang, Chen Sun, Kevin Murphy, Li-Jia Li, Li Fei-Fei, James Hays. - "Composing Text and Image for Image Retrieval - An Empirical Odyssey" - CVPR 2019. arXiv:1812.07119 - """ - - def __init__(self, texts, embed_dim): - super(TIRGLastConv, self).__init__(texts, embed_dim) - - self.a = torch.nn.Parameter(torch.tensor([1.0, 10.0, 1.0, 1.0])) - self.mod2d = torch.nn.Sequential( - torch.nn.BatchNorm2d(512 + embed_dim), - torch.nn.Conv2d(512 + embed_dim, 512 + embed_dim, [3, 3], padding=1), - torch.nn.ReLU(), - torch.nn.Conv2d(512 + embed_dim, 512, [3, 3], padding=1), - ) - - self.mod2d_gate = torch.nn.Sequential( - torch.nn.BatchNorm2d(512 + embed_dim), - torch.nn.Conv2d(512 + embed_dim, 512 + embed_dim, [3, 3], padding=1), - torch.nn.ReLU(), - torch.nn.Conv2d(512 + embed_dim, 512, [3, 3], padding=1), - ) - - def compose_img_text(self, imgs, texts): - text_features = self.extract_text_feature(texts) - - x = imgs - x = self.img_model.conv1(x) - x = self.img_model.bn1(x) - x = self.img_model.relu(x) - x = self.img_model.maxpool(x) - - x = self.img_model.layer1(x) - x = self.img_model.layer2(x) - x = self.img_model.layer3(x) - x = self.img_model.layer4(x) - - # mod - y = text_features - y = y.reshape((y.shape[0], y.shape[1], 1, 1)).repeat( - 1, 1, x.shape[2], x.shape[3]) - z = torch.cat((x, y), dim=1) - t = self.mod2d(z) - tgate = self.mod2d_gate(z) - x = self.a[0] * F.sigmoid(tgate) * x + self.a[1] * t - - x = self.img_model.avgpool(x) - x = x.view(x.size(0), -1) - x = self.img_model.fc(x) - return x + """The TIGR model with spatial modification over the last conv layer. + + The method is described in + Nam Vo, Lu Jiang, Chen Sun, Kevin Murphy, Li-Jia Li, Li Fei-Fei, James Hays. + "Composing Text and Image for Image Retrieval - An Empirical Odyssey" + CVPR 2019. 
arXiv:1812.07119 + """ + + def __init__(self, texts, embed_dim): + super(TIRGLastConv, self).__init__(texts, embed_dim) + + self.a = torch.nn.Parameter(torch.tensor([1.0, 10.0, 1.0, 1.0])) + self.mod2d = torch.nn.Sequential( + torch.nn.BatchNorm2d(512 + embed_dim), + torch.nn.Conv2d(512 + embed_dim, 512 + + embed_dim, [3, 3], padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(512 + embed_dim, 512, [3, 3], padding=1), + ) + + self.mod2d_gate = torch.nn.Sequential( + torch.nn.BatchNorm2d(512 + embed_dim), + torch.nn.Conv2d(512 + embed_dim, 512 + + embed_dim, [3, 3], padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(512 + embed_dim, 512, [3, 3], padding=1), + ) + + def compose_img_text(self, imgs, texts): + text_features = self.extract_text_feature(texts) + + x = imgs + x = self.img_model.conv1(x) + x = self.img_model.bn1(x) + x = self.img_model.relu(x) + x = self.img_model.maxpool(x) + + x = self.img_model.layer1(x) + x = self.img_model.layer2(x) + x = self.img_model.layer3(x) + x = self.img_model.layer4(x) + + # mod + y = text_features + y = y.reshape((y.shape[0], y.shape[1], 1, 1)).repeat( + 1, 1, x.shape[2], x.shape[3]) + z = torch.cat((x, y), dim=1) + t = self.mod2d(z) + tgate = self.mod2d_gate(z) + x = self.a[0] * F.sigmoid(tgate) * x + self.a[1] * t + + x = self.img_model.avgpool(x) + x = x.view(x.size(0), -1) + x = self.img_model.fc(x) + return x diff --git a/main.py b/main.py index 6a91b08..497d628 100644 --- a/main.py +++ b/main.py @@ -35,266 +35,268 @@ def parse_opt(): - """Parses the input arguments.""" - parser = argparse.ArgumentParser() - parser.add_argument('-f', type=str, default='') - parser.add_argument('--comment', type=str, default='test_notebook') - parser.add_argument('--dataset', type=str, default='css3d') - parser.add_argument( - '--dataset_path', type=str, default='../imgcomsearch/CSSDataset/output') - parser.add_argument('--model', type=str, default='tirg') - parser.add_argument('--embed_dim', type=int, default=512) - parser.add_argument('--learning_rate', type=float, default=1e-2) - parser.add_argument( - '--learning_rate_decay_frequency', type=int, default=9999999) - parser.add_argument('--batch_size', type=int, default=32) - parser.add_argument('--weight_decay', type=float, default=1e-6) - parser.add_argument('--num_iters', type=int, default=210000) - parser.add_argument('--loss', type=str, default='soft_triplet') - parser.add_argument('--loader_num_workers', type=int, default=4) - args = parser.parse_args() - return args + """Parses the input arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument('-f', type=str, default='') + parser.add_argument('--comment', type=str, default='test_notebook') + parser.add_argument('--dataset', type=str, default='css3d') + parser.add_argument( + '--dataset_path', type=str, default='../imgcomsearch/CSSDataset/output') + parser.add_argument('--model', type=str, default='tirg') + parser.add_argument('--embed_dim', type=int, default=512) + parser.add_argument('--learning_rate', type=float, default=1e-2) + parser.add_argument( + '--learning_rate_decay_frequency', type=int, default=9999999) + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--weight_decay', type=float, default=1e-6) + parser.add_argument('--num_iters', type=int, default=210000) + parser.add_argument('--loss', type=str, default='soft_triplet') + parser.add_argument('--loader_num_workers', type=int, default=4) + args = parser.parse_args() + return args def load_dataset(opt): - """Loads the input datasets.""" - print 'Reading 
dataset ', opt.dataset - if opt.dataset == 'css3d': - trainset = datasets.CSSDataset( - path=opt.dataset_path, - split='train', - transform=torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize([0.485, 0.456, 0.406], - [0.229, 0.224, 0.225]) - ])) - testset = datasets.CSSDataset( - path=opt.dataset_path, - split='test', - transform=torchvision.transforms.Compose([ - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize([0.485, 0.456, 0.406], - [0.229, 0.224, 0.225]) - ])) - elif opt.dataset == 'fashion200k': - trainset = datasets.Fashion200k( - path=opt.dataset_path, - split='train', - transform=torchvision.transforms.Compose([ - torchvision.transforms.Resize(224), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize([0.485, 0.456, 0.406], - [0.229, 0.224, 0.225]) - ])) - testset = datasets.Fashion200k( - path=opt.dataset_path, - split='test', - transform=torchvision.transforms.Compose([ - torchvision.transforms.Resize(224), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize([0.485, 0.456, 0.406], - [0.229, 0.224, 0.225]) - ])) - elif opt.dataset == 'mitstates': - trainset = datasets.MITStates( - path=opt.dataset_path, - split='train', - transform=torchvision.transforms.Compose([ - torchvision.transforms.Resize(224), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize([0.485, 0.456, 0.406], - [0.229, 0.224, 0.225]) - ])) - testset = datasets.MITStates( - path=opt.dataset_path, - split='test', - transform=torchvision.transforms.Compose([ - torchvision.transforms.Resize(224), - torchvision.transforms.CenterCrop(224), - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize([0.485, 0.456, 0.406], - [0.229, 0.224, 0.225]) - ])) - else: - print 'Invalid dataset', opt.dataset - sys.exit() - - print 'trainset size:', len(trainset) - print 'testset size:', len(testset) - return trainset, testset + """Loads the input datasets.""" + print('Reading dataset ', opt.dataset) + if opt.dataset == 'css3d': + trainset = datasets.CSSDataset( + path=opt.dataset_path, + split='train', + transform=torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ])) + testset = datasets.CSSDataset( + path=opt.dataset_path, + split='test', + transform=torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ])) + elif opt.dataset == 'fashion200k': + trainset = datasets.Fashion200k( + path=opt.dataset_path, + split='train', + transform=torchvision.transforms.Compose([ + torchvision.transforms.Resize(224), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ])) + testset = datasets.Fashion200k( + path=opt.dataset_path, + split='test', + transform=torchvision.transforms.Compose([ + torchvision.transforms.Resize(224), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ])) + elif opt.dataset == 'mitstates': + trainset = datasets.MITStates( + path=opt.dataset_path, + split='train', + transform=torchvision.transforms.Compose([ + 
torchvision.transforms.Resize(224), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ])) + testset = datasets.MITStates( + path=opt.dataset_path, + split='test', + transform=torchvision.transforms.Compose([ + torchvision.transforms.Resize(224), + torchvision.transforms.CenterCrop(224), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ])) + else: + print('Invalid dataset', opt.dataset) + sys.exit() + + print('trainset size:', len(trainset)) + print('testset size:', len(testset)) + return trainset, testset def create_model_and_optimizer(opt, texts): - """Builds the model and related optimizer.""" - print 'Creating model and optimizer for', opt.model - if opt.model == 'imgonly': - model = img_text_composition_models.SimpleModelImageOnly( - texts, embed_dim=opt.embed_dim) - elif opt.model == 'textonly': - model = img_text_composition_models.SimpleModelTextOnly( - texts, embed_dim=opt.embed_dim) - elif opt.model == 'concat': - model = img_text_composition_models.Concat(texts, embed_dim=opt.embed_dim) - elif opt.model == 'tirg': - model = img_text_composition_models.TIRG(texts, embed_dim=opt.embed_dim) - elif opt.model == 'tirg_lastconv': - model = img_text_composition_models.TIRGLastConv( - texts, embed_dim=opt.embed_dim) - else: - print 'Invalid model', opt.model - print 'available: imgonly, textonly, concat, tirg or tirg_lastconv' - sys.exit() - model = model.cuda() - - # create optimizer - params = [] - # low learning rate for pretrained layers on real image datasets - if opt.dataset != 'css3d': - params.append({ - 'params': [p for p in model.img_model.fc.parameters()], - 'lr': opt.learning_rate - }) - params.append({ - 'params': [p for p in model.img_model.parameters()], - 'lr': 0.1 * opt.learning_rate - }) - params.append({'params': [p for p in model.parameters()]}) - for _, p1 in enumerate(params): # remove duplicated params - for _, p2 in enumerate(params): - if p1 is not p2: - for p11 in p1['params']: - for j, p22 in enumerate(p2['params']): - if p11 is p22: - p2['params'][j] = torch.tensor(0.0, requires_grad=True) - optimizer = torch.optim.SGD( - params, lr=opt.learning_rate, momentum=0.9, weight_decay=opt.weight_decay) - return model, optimizer + """Builds the model and related optimizer.""" + print('Creating model and optimizer for', opt.model) + if opt.model == 'imgonly': + model = img_text_composition_models.SimpleModelImageOnly( + texts, embed_dim=opt.embed_dim) + elif opt.model == 'textonly': + model = img_text_composition_models.SimpleModelTextOnly( + texts, embed_dim=opt.embed_dim) + elif opt.model == 'concat': + model = img_text_composition_models.Concat( + texts, embed_dim=opt.embed_dim) + elif opt.model == 'tirg': + model = img_text_composition_models.TIRG( + texts, embed_dim=opt.embed_dim) + elif opt.model == 'tirg_lastconv': + model = img_text_composition_models.TIRGLastConv( + texts, embed_dim=opt.embed_dim) + else: + print('Invalid model', opt.model) + print('available: imgonly, textonly, concat, tirg or tirg_lastconv') + sys.exit() + model = model.cuda() + + # create optimizer + params = [] + # low learning rate for pretrained layers on real image datasets + if opt.dataset != 'css3d': + params.append({ + 'params': [p for p in model.img_model.fc.parameters()], + 'lr': opt.learning_rate + }) + params.append({ + 'params': [p for p in model.img_model.parameters()], + 'lr': 0.1 * 
opt.learning_rate + }) + params.append({'params': [p for p in model.parameters()]}) + for _, p1 in enumerate(params): # remove duplicated params + for _, p2 in enumerate(params): + if p1 is not p2: + for p11 in p1['params']: + for j, p22 in enumerate(p2['params']): + if p11 is p22: + p2['params'][j] = torch.tensor( + 0.0, requires_grad=True) + optimizer = torch.optim.SGD( + params, lr=opt.learning_rate, momentum=0.9, weight_decay=opt.weight_decay) + return model, optimizer def train_loop(opt, logger, trainset, testset, model, optimizer): - """Function for train loop""" - print 'Begin training' - losses_tracking = {} - it = 0 - epoch = -1 - tic = time.time() - while it < opt.num_iters: - epoch += 1 - - # show/log stats - print 'It', it, 'epoch', epoch, 'Elapsed time', round(time.time() - tic, - 4), opt.comment + """Function for train loop""" + print('Begin training') + losses_tracking = {} + it = 0 + epoch = -1 tic = time.time() - for loss_name in losses_tracking: - avg_loss = np.mean(losses_tracking[loss_name][-len(trainloader):]) - print ' Loss', loss_name, round(avg_loss, 4) - logger.add_scalar(loss_name, avg_loss, it) - logger.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], it) - - # test - if epoch % 3 == 1: - tests = [] - for name, dataset in [('train', trainset), ('test', testset)]: - t = test_retrieval.test(opt, model, dataset) - tests += [(name + ' ' + metric_name, metric_value) - for metric_name, metric_value in t] - for metric_name, metric_value in tests: - logger.add_scalar(metric_name, metric_value, it) - print ' ', metric_name, round(metric_value, 4) - - # save checkpoint - torch.save({ - 'it': it, - 'opt': opt, - 'model_state_dict': model.state_dict(), - }, - logger.file_writer.get_logdir() + '/latest_checkpoint.pth') - - # run trainning for 1 epoch - model.train() - trainloader = trainset.get_loader( - batch_size=opt.batch_size, - shuffle=True, - drop_last=True, - num_workers=opt.loader_num_workers) - - - def training_1_iter(data): - assert type(data) is list - img1 = np.stack([d['source_img_data'] for d in data]) - img1 = torch.from_numpy(img1).float() - img1 = torch.autograd.Variable(img1).cuda() - img2 = np.stack([d['target_img_data'] for d in data]) - img2 = torch.from_numpy(img2).float() - img2 = torch.autograd.Variable(img2).cuda() - mods = [str(d['mod']['str']) for d in data] - mods = [t.decode('utf-8') for t in mods] - - # compute loss - losses = [] - if opt.loss == 'soft_triplet': - loss_value = model.compute_loss( - img1, mods, img2, soft_triplet_loss=True) - elif opt.loss == 'batch_based_classification': - loss_value = model.compute_loss( - img1, mods, img2, soft_triplet_loss=False) - else: - print 'Invalid loss function', opt.loss - sys.exit() - loss_name = opt.loss - loss_weight = 1.0 - losses += [(loss_name, loss_weight, loss_value)] - total_loss = sum([ - loss_weight * loss_value - for loss_name, loss_weight, loss_value in losses - ]) - assert not torch.isnan(total_loss) - losses += [('total training loss', None, total_loss)] - - # track losses - for loss_name, loss_weight, loss_value in losses: - if not losses_tracking.has_key(loss_name): - losses_tracking[loss_name] = [] - losses_tracking[loss_name].append(float(loss_value)) - - # gradient descend - optimizer.zero_grad() - total_loss.backward() - optimizer.step() - - for data in tqdm(trainloader, desc='Training for epoch ' + str(epoch)): - it += 1 - training_1_iter(data) - - # decay learing rate - if it >= opt.learning_rate_decay_frequency and it % opt.learning_rate_decay_frequency == 0: - for g 
in optimizer.param_groups: - g['lr'] *= 0.1 - - print 'Finished training' + while it < opt.num_iters: + epoch += 1 + + # show/log stats + print('It', it, 'epoch', epoch, 'Elapsed time', round(time.time() - tic, + 4), opt.comment) + tic = time.time() + for loss_name in losses_tracking: + avg_loss = np.mean(losses_tracking[loss_name][-len(trainloader):]) + print(' Loss', loss_name, round(avg_loss, 4)) + logger.add_scalar(loss_name, avg_loss, it) + logger.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], it) + + # test + if epoch % 3 == 1: + tests = [] + for name, dataset in [('train', trainset), ('test', testset)]: + t = test_retrieval.test(opt, model, dataset) + tests += [(name + ' ' + metric_name, metric_value) + for metric_name, metric_value in t] + for metric_name, metric_value in tests: + logger.add_scalar(metric_name, metric_value, it) + print(' ', metric_name, round(metric_value, 4)) + + # save checkpoint + torch.save({ + 'it': it, + 'opt': opt, + 'model_state_dict': model.state_dict(), + }, + logger.file_writer.get_logdir() + '/latest_checkpoint.pth') + + # run trainning for 1 epoch + model.train() + trainloader = trainset.get_loader( + batch_size=opt.batch_size, + shuffle=True, + drop_last=True, + num_workers=opt.loader_num_workers) + + def training_1_iter(data): + assert type(data) is list + img1 = np.stack([d['source_img_data'] for d in data]) + img1 = torch.from_numpy(img1).float() + img1 = torch.autograd.Variable(img1).cuda() + img2 = np.stack([d['target_img_data'] for d in data]) + img2 = torch.from_numpy(img2).float() + img2 = torch.autograd.Variable(img2).cuda() + mods = [str(d['mod']['str']) for d in data] + mods = [t.decode('utf-8') for t in mods] + + # compute loss + losses = [] + if opt.loss == 'soft_triplet': + loss_value = model.compute_loss( + img1, mods, img2, soft_triplet_loss=True) + elif opt.loss == 'batch_based_classification': + loss_value = model.compute_loss( + img1, mods, img2, soft_triplet_loss=False) + else: + print('Invalid loss function', opt.loss) + sys.exit() + loss_name = opt.loss + loss_weight = 1.0 + losses += [(loss_name, loss_weight, loss_value)] + total_loss = sum([ + loss_weight * loss_value + for loss_name, loss_weight, loss_value in losses + ]) + assert not torch.isnan(total_loss) + losses += [('total training loss', None, total_loss)] + + # track losses + for loss_name, loss_weight, loss_value in losses: + if not (loss_name in losses_tracking): + losses_tracking[loss_name] = [] + losses_tracking[loss_name].append(float(loss_value)) + + # gradient descend + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + + for data in tqdm(trainloader, desc='Training for epoch ' + str(epoch)): + it += 1 + training_1_iter(data) + + # decay learing rate + if it >= opt.learning_rate_decay_frequency and it % opt.learning_rate_decay_frequency == 0: + for g in optimizer.param_groups: + g['lr'] *= 0.1 + + print('Finished training') def main(): - opt = parse_opt() - print 'Arguments:' - for k in opt.__dict__.keys(): - print ' ', k, ':', str(opt.__dict__[k]) + opt = parse_opt() + print('Arguments:') + for k in opt.__dict__.keys(): + print(' ', k, ':', str(opt.__dict__[k])) - logger = SummaryWriter(comment=opt.comment) - print 'Log files saved to', logger.file_writer.get_logdir() - for k in opt.__dict__.keys(): - logger.add_text(k, str(opt.__dict__[k])) + logger = SummaryWriter(comment=opt.comment) + print('Log files saved to', logger.file_writer.get_logdir()) + for k in opt.__dict__.keys(): + logger.add_text(k, str(opt.__dict__[k])) - 
trainset, testset = load_dataset(opt) - model, optimizer = create_model_and_optimizer( - opt, [t.decode('utf-8') for t in trainset.get_all_texts()]) + trainset, testset = load_dataset(opt) + model, optimizer = create_model_and_optimizer( + opt, [t.decode('utf-8') if isinstance(t, bytes) else t for t in trainset.get_all_texts()]) - train_loop(opt, logger, trainset, testset, model, optimizer) - logger.close() + train_loop(opt, logger, trainset, testset, model, optimizer) + logger.close() if __name__ == '__main__': - main() + main() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..33a723c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +torch==0.4.1 +tqdm==4.29.1 +torchvision==0.2.1 +numpy==1.17.4 +Pillow==7.0.0 +scikit-image +tensorboardX==2.0 diff --git a/test_retrieval.py b/test_retrieval.py index 60fc233..0086ec9 100644 --- a/test_retrieval.py +++ b/test_retrieval.py @@ -20,115 +20,117 @@ def test(opt, model, testset): - """Tests a model over the given testset.""" - model.eval() - test_queries = testset.get_test_queries() + """Tests a model over the given testset.""" + model.eval() + test_queries = testset.get_test_queries() - all_imgs = [] - all_captions = [] - all_queries = [] - all_target_captions = [] - if test_queries: - # compute test query features - imgs = [] - mods = [] - for t in tqdm(test_queries): - imgs += [testset.get_img(t['source_img_id'])] - mods += [t['mod']['str']] - if len(imgs) >= opt.batch_size or t is test_queries[-1]: - if 'torch' not in str(type(imgs[0])): - imgs = [torch.from_numpy(d).float() for d in imgs] - imgs = torch.stack(imgs).float() - imgs = torch.autograd.Variable(imgs).cuda() - mods = [t.decode('utf-8') for t in mods] - f = model.compose_img_text(imgs, mods).data.cpu().numpy() - all_queries += [f] + all_imgs = [] + all_captions = [] + all_queries = [] + all_target_captions = [] + if test_queries: + # compute test query features imgs = [] mods = [] - all_queries = np.concatenate(all_queries) - all_target_captions = [t['target_caption'] for t in test_queries] + for t in tqdm(test_queries): + imgs += [testset.get_img(t['source_img_id'])] + mods += [t['mod']['str']] + if len(imgs) >= opt.batch_size or t is test_queries[-1]: + if 'torch' not in str(type(imgs[0])): + imgs = [torch.from_numpy(d).float() for d in imgs] + imgs = torch.stack(imgs).float() + imgs = torch.autograd.Variable(imgs).cuda() + mods = [t.decode('utf-8') if isinstance(t, bytes) else t for t in mods] + f = model.compose_img_text(imgs, mods).data.cpu().numpy() + all_queries += [f] + imgs = [] + mods = [] + all_queries = np.concatenate(all_queries) + all_target_captions = [t['target_caption'] for t in test_queries] - # compute all image features - imgs = [] - for i in tqdm(range(len(testset.imgs))): - imgs += [testset.get_img(i)] - if len(imgs) >= opt.batch_size or i == len(testset.imgs) - 1: - if 'torch' not in str(type(imgs[0])): - imgs = [torch.from_numpy(d).float() for d in imgs] - imgs = torch.stack(imgs).float() - imgs = torch.autograd.Variable(imgs).cuda() - imgs = model.extract_img_feature(imgs).data.cpu().numpy() - all_imgs += [imgs] + # compute all image features imgs = [] - all_imgs = np.concatenate(all_imgs) - all_captions = [img['captions'][0] for img in testset.imgs] + for i in tqdm(range(len(testset.imgs))): + imgs += [testset.get_img(i)] + if len(imgs) >= opt.batch_size or i == len(testset.imgs) - 1: + if 'torch' not in str(type(imgs[0])): + imgs = [torch.from_numpy(d).float() for d in imgs] + imgs = torch.stack(imgs).float() + imgs = torch.autograd.Variable(imgs).cuda() + imgs =
model.extract_img_feature(imgs).data.cpu().numpy() + all_imgs += [imgs] + imgs = [] + all_imgs = np.concatenate(all_imgs) + all_captions = [img['captions'][0] for img in testset.imgs] - else: - # use training queries to approximate training retrieval performance - imgs0 = [] - imgs = [] - mods = [] - for i in range(10000): - item = testset[i] - imgs += [item['source_img_data']] - mods += [item['mod']['str']] - if len(imgs) > opt.batch_size or i == 9999: - imgs = torch.stack(imgs).float() - imgs = torch.autograd.Variable(imgs) - mods = [t.decode('utf-8') for t in mods] - f = model.compose_img_text(imgs.cuda(), mods).data.cpu().numpy() - all_queries += [f] + else: + # use training queries to approximate training retrieval performance + imgs0 = [] imgs = [] mods = [] - imgs0 += [item['target_img_data']] - if len(imgs0) > opt.batch_size or i == 9999: - imgs0 = torch.stack(imgs0).float() - imgs0 = torch.autograd.Variable(imgs0) - imgs0 = model.extract_img_feature(imgs0.cuda()).data.cpu().numpy() - all_imgs += [imgs0] - imgs0 = [] - all_captions += [item['target_caption']] - all_target_captions += [item['target_caption']] - all_imgs = np.concatenate(all_imgs) - all_queries = np.concatenate(all_queries) + for i in range(10000): + item = testset[i] + imgs += [item['source_img_data']] + mods += [item['mod']['str']] + if len(imgs) > opt.batch_size or i == 9999: + imgs = torch.stack(imgs).float() + imgs = torch.autograd.Variable(imgs) + mods = [t.decode('utf-8') for t in mods] + f = model.compose_img_text( + imgs.cuda(), mods).data.cpu().numpy() + all_queries += [f] + imgs = [] + mods = [] + imgs0 += [item['target_img_data']] + if len(imgs0) > opt.batch_size or i == 9999: + imgs0 = torch.stack(imgs0).float() + imgs0 = torch.autograd.Variable(imgs0) + imgs0 = model.extract_img_feature( + imgs0.cuda()).data.cpu().numpy() + all_imgs += [imgs0] + imgs0 = [] + all_captions += [item['target_caption']] + all_target_captions += [item['target_caption']] + all_imgs = np.concatenate(all_imgs) + all_queries = np.concatenate(all_queries) - # feature normalization - for i in range(all_queries.shape[0]): - all_queries[i, :] /= np.linalg.norm(all_queries[i, :]) - for i in range(all_imgs.shape[0]): - all_imgs[i, :] /= np.linalg.norm(all_imgs[i, :]) + # feature normalization + for i in range(all_queries.shape[0]): + all_queries[i, :] /= np.linalg.norm(all_queries[i, :]) + for i in range(all_imgs.shape[0]): + all_imgs[i, :] /= np.linalg.norm(all_imgs[i, :]) - # match test queries to target images, get nearest neighbors - sims = all_queries.dot(all_imgs.T) - if test_queries: - for i, t in enumerate(test_queries): - sims[i, t['source_img_id']] = -10e10 # remove query image - nn_result = [np.argsort(-sims[i, :])[:110] for i in range(sims.shape[0])] + # match test queries to target images, get nearest neighbors + sims = all_queries.dot(all_imgs.T) + if test_queries: + for i, t in enumerate(test_queries): + sims[i, t['source_img_id']] = -10e10 # remove query image + nn_result = [np.argsort(-sims[i, :])[:110] for i in range(sims.shape[0])] - # compute recalls - out = [] - nn_result = [[all_captions[nn] for nn in nns] for nns in nn_result] - for k in [1, 5, 10, 50, 100]: - r = 0.0 - for i, nns in enumerate(nn_result): - if all_target_captions[i] in nns[:k]: - r += 1 - r /= len(nn_result) - out += [('recall_top' + str(k) + '_correct_composition', r)] + # compute recalls + out = [] + nn_result = [[all_captions[nn] for nn in nns] for nns in nn_result] + for k in [1, 5, 10, 50, 100]: + r = 0.0 + for i, nns in 
enumerate(nn_result): + if all_target_captions[i] in nns[:k]: + r += 1 + r /= len(nn_result) + out += [('recall_top' + str(k) + '_correct_composition', r)] - if opt.dataset == 'mitstates': - r = 0.0 - for i, nns in enumerate(nn_result): - if all_target_captions[i].split()[0] in [c.split()[0] for c in nns[:k]]: - r += 1 - r /= len(nn_result) - out += [('recall_top' + str(k) + '_correct_adj', r)] + if opt.dataset == 'mitstates': + r = 0.0 + for i, nns in enumerate(nn_result): + if all_target_captions[i].split()[0] in [c.split()[0] for c in nns[:k]]: + r += 1 + r /= len(nn_result) + out += [('recall_top' + str(k) + '_correct_adj', r)] - r = 0.0 - for i, nns in enumerate(nn_result): - if all_target_captions[i].split()[1] in [c.split()[1] for c in nns[:k]]: - r += 1 - r /= len(nn_result) - out += [('recall_top' + str(k) + '_correct_noun', r)] + r = 0.0 + for i, nns in enumerate(nn_result): + if all_target_captions[i].split()[1] in [c.split()[1] for c in nns[:k]]: + r += 1 + r /= len(nn_result) + out += [('recall_top' + str(k) + '_correct_noun', r)] - return out + return out diff --git a/text_model.py b/text_model.py index dbe3c6a..c979285 100644 --- a/text_model.py +++ b/text_model.py @@ -21,103 +21,103 @@ class SimpleVocab(object): - def __init__(self): - super(SimpleVocab, self).__init__() - self.word2id = {} - self.wordcount = {} - self.word2id[''] = 0 - self.wordcount[''] = 9e9 - - def tokenize_text(self, text): - text = text.encode('ascii', 'ignore').decode('ascii') - tokens = str(text).lower().translate(None, - string.punctuation).strip().split() - return tokens - - def add_text_to_vocab(self, text): - tokens = self.tokenize_text(text) - for token in tokens: - if not self.word2id.has_key(token): - self.word2id[token] = len(self.word2id) - self.wordcount[token] = 0 - self.wordcount[token] += 1 - - def threshold_rare_words(self, wordcount_threshold=5): - for w in self.word2id: - if self.wordcount[w] < wordcount_threshold: - self.word2id[w] = 0 - - def encode_text(self, text): - tokens = self.tokenize_text(text) - x = [self.word2id.get(t, 0) for t in tokens] - return x - - def get_size(self): - return len(self.word2id) + def __init__(self): + super(SimpleVocab, self).__init__() + self.word2id = {} + self.wordcount = {} + self.word2id[''] = 0 + self.wordcount[''] = 9e9 + + def tokenize_text(self, text): + text = text.encode('ascii', 'ignore').decode('ascii') + tokens = str(text).lower().translate( + str.maketrans('', '', string.punctuation)).strip().split() + return tokens + + def add_text_to_vocab(self, text): + tokens = self.tokenize_text(text) + for token in tokens: + if token not in self.word2id: + self.word2id[token] = len(self.word2id) + self.wordcount[token] = 0 + self.wordcount[token] += 1 + + def threshold_rare_words(self, wordcount_threshold=5): + for w in self.word2id: + if self.wordcount[w] < wordcount_threshold: + self.word2id[w] = 0 + + def encode_text(self, text): + tokens = self.tokenize_text(text) + x = [self.word2id.get(t, 0) for t in tokens] + return x + + def get_size(self): + return len(self.word2id) class TextLSTMModel(torch.nn.Module): - def __init__(self, - texts_to_build_vocab, - word_embed_dim=512, - lstm_hidden_dim=512): - - super(TextLSTMModel, self).__init__() - - self.vocab = SimpleVocab() - for text in texts_to_build_vocab: - self.vocab.add_text_to_vocab(text) - vocab_size = self.vocab.get_size() - - self.word_embed_dim = word_embed_dim - self.lstm_hidden_dim = lstm_hidden_dim - self.embedding_layer = torch.nn.Embedding(vocab_size, word_embed_dim) - self.lstm =
torch.nn.LSTM(word_embed_dim, lstm_hidden_dim) - self.fc_output = torch.nn.Sequential( - torch.nn.Dropout(p=0.1), - torch.nn.Linear(lstm_hidden_dim, lstm_hidden_dim), - ) - - def forward(self, x): - """ input x: list of strings""" - if type(x) is list: - if type(x[0]) is str or type(x[0]) is unicode: - x = [self.vocab.encode_text(text) for text in x] - - assert type(x) is list - assert type(x[0]) is list - assert type(x[0][0]) is int - return self.forward_encoded_texts(x) - - def forward_encoded_texts(self, texts): - # to tensor - lengths = [len(t) for t in texts] - itexts = torch.zeros((np.max(lengths), len(texts))).long() - for i in range(len(texts)): - itexts[:lengths[i], i] = torch.tensor(texts[i]) - - # embed words - itexts = torch.autograd.Variable(itexts).cuda() - etexts = self.embedding_layer(itexts) - - # lstm - lstm_output, _ = self.forward_lstm_(etexts) - - # get last output (using length) - text_features = [] - for i in range(len(texts)): - text_features.append(lstm_output[lengths[i] - 1, i, :]) - - # output - text_features = torch.stack(text_features) - text_features = self.fc_output(text_features) - return text_features - - def forward_lstm_(self, etexts): - batch_size = etexts.shape[1] - first_hidden = (torch.zeros(1, batch_size, self.lstm_hidden_dim), - torch.zeros(1, batch_size, self.lstm_hidden_dim)) - first_hidden = (first_hidden[0].cuda(), first_hidden[1].cuda()) - lstm_output, last_hidden = self.lstm(etexts, first_hidden) - return lstm_output, last_hidden + def __init__(self, + texts_to_build_vocab, + word_embed_dim=512, + lstm_hidden_dim=512): + + super(TextLSTMModel, self).__init__() + + self.vocab = SimpleVocab() + for text in texts_to_build_vocab: + self.vocab.add_text_to_vocab(text) + vocab_size = self.vocab.get_size() + + self.word_embed_dim = word_embed_dim + self.lstm_hidden_dim = lstm_hidden_dim + self.embedding_layer = torch.nn.Embedding(vocab_size, word_embed_dim) + self.lstm = torch.nn.LSTM(word_embed_dim, lstm_hidden_dim) + self.fc_output = torch.nn.Sequential( + torch.nn.Dropout(p=0.1), + torch.nn.Linear(lstm_hidden_dim, lstm_hidden_dim), + ) + + def forward(self, x): + """ input x: list of strings""" + if type(x) is list: + if isinstance(x[0], str): + x = [self.vocab.encode_text(text) for text in x] + + assert type(x) is list + assert type(x[0]) is list + assert type(x[0][0]) is int + return self.forward_encoded_texts(x) + + def forward_encoded_texts(self, texts): + # to tensor + lengths = [len(t) for t in texts] + itexts = torch.zeros((np.max(lengths), len(texts))).long() + for i in range(len(texts)): + itexts[:lengths[i], i] = torch.tensor(texts[i]) + + # embed words + itexts = torch.autograd.Variable(itexts).cuda() + etexts = self.embedding_layer(itexts) + + # lstm + lstm_output, _ = self.forward_lstm_(etexts) + + # get last output (using length) + text_features = [] + for i in range(len(texts)): + text_features.append(lstm_output[lengths[i] - 1, i, :]) + + # output + text_features = torch.stack(text_features) + text_features = self.fc_output(text_features) + return text_features + + def forward_lstm_(self, etexts): + batch_size = etexts.shape[1] + first_hidden = (torch.zeros(1, batch_size, self.lstm_hidden_dim), + torch.zeros(1, batch_size, self.lstm_hidden_dim)) + first_hidden = (first_hidden[0].cuda(), first_hidden[1].cuda()) + lstm_output, last_hidden = self.lstm(etexts, first_hidden) + return lstm_output, last_hidden diff --git a/third_party/torch_functions.py b/third_party/torch_functions.py index 80f8068..3c271e9 100644 --- 
a/third_party/torch_functions.py +++ b/third_party/torch_functions.py @@ -22,103 +22,106 @@ def pairwise_distances(x, y=None): - """ - Input: x is a Nxd matrix - y is an optional Mxd matirx - Output: dist is a NxM matrix where dist[i,j] is the square norm between - x[i,:] and y[j,:] - if y is not given then use 'y=x'. - i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2 - source: - https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2 """ - x_norm = (x**2).sum(1).view(-1, 1) - if y is not None: - y_t = torch.transpose(y, 0, 1) - y_norm = (y**2).sum(1).view(1, -1) - else: - y_t = torch.transpose(x, 0, 1) - y_norm = x_norm.view(1, -1) - - dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) - # Ensure diagonal is zero if x=y - # if y is None: - # dist = dist - torch.diag(dist.diag) - return torch.clamp(dist, 0.0, np.inf) + Input: x is a Nxd matrix + y is an optional Mxd matirx + Output: dist is a NxM matrix where dist[i,j] is the square norm between + x[i,:] and y[j,:] + if y is not given then use 'y=x'. + i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2 + source: + https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2 + """ + x_norm = (x**2).sum(1).view(-1, 1) + if y is not None: + y_t = torch.transpose(y, 0, 1) + y_norm = (y**2).sum(1).view(1, -1) + else: + y_t = torch.transpose(x, 0, 1) + y_norm = x_norm.view(1, -1) + + dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) + # Ensure diagonal is zero if x=y + # if y is None: + # dist = dist - torch.diag(dist.diag) + return torch.clamp(dist, 0.0, np.inf) class MyTripletLossFunc(torch.autograd.Function): - def __init__(self, triplets): - super(MyTripletLossFunc, self).__init__() - self.triplets = triplets - self.triplet_count = len(triplets) - - def forward(self, features): - self.save_for_backward(features) - - self.distances = pairwise_distances(features).cpu().numpy() - - loss = 0.0 - triplet_count = 0.0 - correct_count = 0.0 - for i, j, k in self.triplets: - w = 1.0 - triplet_count += w - loss += w * np.log(1 + - np.exp(self.distances[i, j] - self.distances[i, k])) - if self.distances[i, j] < self.distances[i, k]: - correct_count += 1 - - loss /= triplet_count - return torch.FloatTensor((loss,)) - - def backward(self, grad_output): - features, = self.saved_tensors - features_np = features.cpu().numpy() - grad_features = features.clone() * 0.0 - grad_features_np = grad_features.cpu().numpy() - - for i, j, k in self.triplets: - w = 1.0 - f = 1.0 - 1.0 / ( - 1.0 + np.exp(self.distances[i, j] - self.distances[i, k])) - grad_features_np[i, :] += w * f * ( - features_np[i, :] - features_np[j, :]) / self.triplet_count - grad_features_np[j, :] += w * f * ( - features_np[j, :] - features_np[i, :]) / self.triplet_count - grad_features_np[i, :] += -w * f * ( - features_np[i, :] - features_np[k, :]) / self.triplet_count - grad_features_np[k, :] += -w * f * ( - features_np[k, :] - features_np[i, :]) / self.triplet_count - - for i in range(features_np.shape[0]): - grad_features[i, :] = torch.from_numpy(grad_features_np[i, :]) - grad_features *= float(grad_output.data[0]) - return grad_features + def __init__(self, triplets): + super(MyTripletLossFunc, self).__init__() + self.triplets = triplets + self.triplet_count = len(triplets) + + def forward(self, features): + self.save_for_backward(features) + + self.distances = pairwise_distances(features).cpu().numpy() + + loss = 0.0 + triplet_count = 0.0 + correct_count = 0.0 + for i, j, k in self.triplets: + w = 1.0 + triplet_count += w + loss += w * np.log(1 + + np.exp(self.distances[i, j] - 
self.distances[i, k])) + if self.distances[i, j] < self.distances[i, k]: + correct_count += 1 + + loss /= triplet_count + return torch.FloatTensor((loss,)) + + def backward(self, grad_output): + features, = self.saved_tensors + features_np = features.cpu().numpy() + grad_features = features.clone() * 0.0 + grad_features_np = grad_features.cpu().numpy() + + for i, j, k in self.triplets: + w = 1.0 + f = 1.0 - 1.0 / ( + 1.0 + np.exp(self.distances[i, j] - self.distances[i, k])) + grad_features_np[i, :] += w * f * ( + features_np[i, :] - features_np[j, :]) / self.triplet_count + grad_features_np[j, :] += w * f * ( + features_np[j, :] - features_np[i, :]) / self.triplet_count + grad_features_np[i, :] += -w * f * ( + features_np[i, :] - features_np[k, :]) / self.triplet_count + grad_features_np[k, :] += -w * f * ( + features_np[k, :] - features_np[i, :]) / self.triplet_count + + for i in range(features_np.shape[0]): + grad_features[i, :] = torch.from_numpy(grad_features_np[i, :]) + grad_features *= float(grad_output.data[0]) + return grad_features class TripletLoss(torch.nn.Module): - """Class for the triplet loss.""" - def __init__(self, pre_layer=None): - super(TripletLoss, self).__init__() - self.pre_layer = pre_layer + """Class for the triplet loss.""" + + def __init__(self, pre_layer=None): + super(TripletLoss, self).__init__() + self.pre_layer = pre_layer - def forward(self, x, triplets): - if self.pre_layer is not None: - x = self.pre_layer(x) - loss = MyTripletLossFunc(triplets)(x) - return loss + def forward(self, x, triplets): + if self.pre_layer is not None: + x = self.pre_layer(x) + loss = MyTripletLossFunc(triplets)(x) + return loss class NormalizationLayer(torch.nn.Module): - """Class for normalization layer.""" - def __init__(self, normalize_scale=1.0, learn_scale=True): - super(NormalizationLayer, self).__init__() - self.norm_s = float(normalize_scale) - if learn_scale: - self.norm_s = torch.nn.Parameter(torch.FloatTensor((self.norm_s,))) - - def forward(self, x): - features = self.norm_s * x / torch.norm(x, dim=1, keepdim=True).expand_as(x) - return features + """Class for normalization layer.""" + + def __init__(self, normalize_scale=1.0, learn_scale=True): + super(NormalizationLayer, self).__init__() + self.norm_s = float(normalize_scale) + if learn_scale: + self.norm_s = torch.nn.Parameter(torch.FloatTensor((self.norm_s,))) + + def forward(self, x): + features = self.norm_s * x / \ + torch.norm(x, dim=1, keepdim=True).expand_as(x) + return features diff --git a/torch_functions.py b/torch_functions.py index 86fa209..dd2b41e 100644 --- a/torch_functions.py +++ b/torch_functions.py @@ -26,103 +26,106 @@ def pairwise_distances(x, y=None): - """ - Input: x is a Nxd matrix - y is an optional Mxd matirx - Output: dist is a NxM matrix where dist[i,j] is the square norm between - x[i,:] and y[j,:] - if y is not given then use 'y=x'. - i.e. 
dist[i,j] = ||x[i,:]-y[j,:]||^2 - source: - https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2 """ - x_norm = (x**2).sum(1).view(-1, 1) - if y is not None: - y_t = torch.transpose(y, 0, 1) - y_norm = (y**2).sum(1).view(1, -1) - else: - y_t = torch.transpose(x, 0, 1) - y_norm = x_norm.view(1, -1) - - dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) - # Ensure diagonal is zero if x=y - # if y is None: - # dist = dist - torch.diag(dist.diag) - return torch.clamp(dist, 0.0, np.inf) + Input: x is a Nxd matrix + y is an optional Mxd matirx + Output: dist is a NxM matrix where dist[i,j] is the square norm between + x[i,:] and y[j,:] + if y is not given then use 'y=x'. + i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2 + source: + https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2 + """ + x_norm = (x**2).sum(1).view(-1, 1) + if y is not None: + y_t = torch.transpose(y, 0, 1) + y_norm = (y**2).sum(1).view(1, -1) + else: + y_t = torch.transpose(x, 0, 1) + y_norm = x_norm.view(1, -1) + + dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) + # Ensure diagonal is zero if x=y + # if y is None: + # dist = dist - torch.diag(dist.diag) + return torch.clamp(dist, 0.0, np.inf) class MyTripletLossFunc(torch.autograd.Function): - def __init__(self, triplets): - super(MyTripletLossFunc, self).__init__() - self.triplets = triplets - self.triplet_count = len(triplets) - - def forward(self, features): - self.save_for_backward(features) - - self.distances = pairwise_distances(features).cpu().numpy() - - loss = 0.0 - triplet_count = 0.0 - correct_count = 0.0 - for i, j, k in self.triplets: - w = 1.0 - triplet_count += w - loss += w * np.log(1 + - np.exp(self.distances[i, j] - self.distances[i, k])) - if self.distances[i, j] < self.distances[i, k]: - correct_count += 1 - - loss /= triplet_count - return torch.FloatTensor((loss,)) - - def backward(self, grad_output): - features, = self.saved_tensors - features_np = features.cpu().numpy() - grad_features = features.clone() * 0.0 - grad_features_np = grad_features.cpu().numpy() - - for i, j, k in self.triplets: - w = 1.0 - f = 1.0 - 1.0 / ( - 1.0 + np.exp(self.distances[i, j] - self.distances[i, k])) - grad_features_np[i, :] += w * f * ( - features_np[i, :] - features_np[j, :]) / self.triplet_count - grad_features_np[j, :] += w * f * ( - features_np[j, :] - features_np[i, :]) / self.triplet_count - grad_features_np[i, :] += -w * f * ( - features_np[i, :] - features_np[k, :]) / self.triplet_count - grad_features_np[k, :] += -w * f * ( - features_np[k, :] - features_np[i, :]) / self.triplet_count - - for i in range(features_np.shape[0]): - grad_features[i, :] = torch.from_numpy(grad_features_np[i, :]) - grad_features *= float(grad_output.data[0]) - return grad_features + def __init__(self, triplets): + super(MyTripletLossFunc, self).__init__() + self.triplets = triplets + self.triplet_count = len(triplets) + + def forward(self, features): + self.save_for_backward(features) + + self.distances = pairwise_distances(features).cpu().numpy() + + loss = 0.0 + triplet_count = 0.0 + correct_count = 0.0 + for i, j, k in self.triplets: + w = 1.0 + triplet_count += w + loss += w * np.log(1 + + np.exp(self.distances[i, j] - self.distances[i, k])) + if self.distances[i, j] < self.distances[i, k]: + correct_count += 1 + + loss /= triplet_count + return torch.FloatTensor((loss,)) + + def backward(self, grad_output): + features, = self.saved_tensors + features_np = features.cpu().numpy() + grad_features = features.clone() * 0.0 + grad_features_np = 
grad_features.cpu().numpy() + + for i, j, k in self.triplets: + w = 1.0 + f = 1.0 - 1.0 / ( + 1.0 + np.exp(self.distances[i, j] - self.distances[i, k])) + grad_features_np[i, :] += w * f * ( + features_np[i, :] - features_np[j, :]) / self.triplet_count + grad_features_np[j, :] += w * f * ( + features_np[j, :] - features_np[i, :]) / self.triplet_count + grad_features_np[i, :] += -w * f * ( + features_np[i, :] - features_np[k, :]) / self.triplet_count + grad_features_np[k, :] += -w * f * ( + features_np[k, :] - features_np[i, :]) / self.triplet_count + + for i in range(features_np.shape[0]): + grad_features[i, :] = torch.from_numpy(grad_features_np[i, :]) + grad_features *= float(grad_output.data[0]) + return grad_features class TripletLoss(torch.nn.Module): - """Class for the triplet loss.""" - def __init__(self, pre_layer=None): - super(TripletLoss, self).__init__() - self.pre_layer = pre_layer + """Class for the triplet loss.""" + + def __init__(self, pre_layer=None): + super(TripletLoss, self).__init__() + self.pre_layer = pre_layer - def forward(self, x, triplets): - if self.pre_layer is not None: - x = self.pre_layer(x) - loss = MyTripletLossFunc(triplets)(x) - return loss + def forward(self, x, triplets): + if self.pre_layer is not None: + x = self.pre_layer(x) + loss = MyTripletLossFunc(triplets)(x) + return loss class NormalizationLayer(torch.nn.Module): - """Class for normalization layer.""" - def __init__(self, normalize_scale=1.0, learn_scale=True): - super(NormalizationLayer, self).__init__() - self.norm_s = float(normalize_scale) - if learn_scale: - self.norm_s = torch.nn.Parameter(torch.FloatTensor((self.norm_s,))) - - def forward(self, x): - features = self.norm_s * x / torch.norm(x, dim=1, keepdim=True).expand_as(x) - return features + """Class for normalization layer.""" + + def __init__(self, normalize_scale=1.0, learn_scale=True): + super(NormalizationLayer, self).__init__() + self.norm_s = float(normalize_scale) + if learn_scale: + self.norm_s = torch.nn.Parameter(torch.FloatTensor((self.norm_s,))) + + def forward(self, x): + features = self.norm_s * x / \ + torch.norm(x, dim=1, keepdim=True).expand_as(x) + return features
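
Note: the recall@k numbers reported by `test_retrieval.test()` reduce to a cosine-similarity ranking over L2-normalized features. The minimal sketch below reproduces that ranking logic in isolation on random features, so it can be sanity-checked without a trained model or dataset; it matches targets by index rather than by caption string, and every name in it is illustrative rather than part of this codebase.

```python
# Standalone sketch of the recall@k ranking used in test_retrieval.test().
# Matches by target index (not caption string), purely for illustration.
import numpy as np


def recall_at_k(query_feats, img_feats, target_ids, ks=(1, 5, 10, 50, 100)):
    # L2-normalize rows, mirroring the feature-normalization step in test().
    query_feats = query_feats / np.linalg.norm(query_feats, axis=1, keepdims=True)
    img_feats = img_feats / np.linalg.norm(img_feats, axis=1, keepdims=True)
    sims = query_feats.dot(img_feats.T)    # cosine similarities
    nn_result = np.argsort(-sims, axis=1)  # best match first
    out = []
    for k in ks:
        hits = [target_ids[i] in nn_result[i, :k] for i in range(len(target_ids))]
        out.append(('recall_top' + str(k), float(np.mean(hits))))
    return out


if __name__ == '__main__':
    rng = np.random.RandomState(0)
    queries = rng.randn(32, 512)
    database = rng.randn(200, 512)
    targets = rng.randint(0, 200, size=32)
    for name, value in recall_at_k(queries, database, targets):
        print(name, round(value, 4))
```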