-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata_engine.py
More file actions
40 lines (34 loc) · 1.4 KB
/
data_engine.py
File metadata and controls
40 lines (34 loc) · 1.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from data.msvd import MSVD
from data.msr_vtt import MSRVTT
import cPickle as pickle
import os
class DataEngine(object):
    """Build the MSVD / MSR-VTT dataset objects, caching them as pickles.

    Each public method first looks for a pickle cache file in the current
    working directory and returns its contents when present. Otherwise it
    constructs the dataset object from hard-coded source paths, writes the
    cache file, and returns the freshly built object.
    """

    def __init__(self):
        pass

    @staticmethod
    def _load_cache(cache_file):
        """Return the object stored in the pickle file *cache_file*."""
        # NOTE: unpickling can execute arbitrary code; only load cache files
        # that were produced locally by this class.
        # Binary mode is required for pickle streams, and the context manager
        # guarantees the handle is closed even if loading fails.
        with open(cache_file, 'rb') as handle:
            return pickle.load(handle)

    @staticmethod
    def _save_cache(obj, cache_file):
        """Pickle *obj* into *cache_file*, replacing any existing file."""
        with open(cache_file, 'wb') as handle:
            pickle.dump(obj, handle)

    def msvd(self):
        """Return the MSVD dataset object, using './msvd.pkl' as a cache.

        Returns:
            The object unpickled from './msvd.pkl' if that file exists,
            otherwise a newly constructed ``MSVD`` instance (which is also
            written to './msvd.pkl' before being returned).
        """
        msvd_file = './msvd.pkl'
        if os.path.isfile(msvd_file):
            return self._load_cache(msvd_file)

        msvd_csv_path = '/home/sensetime/data/msvd/MSR_Video_Description_Corpus.csv'
        msvd_video_name2id_map = '/home/sensetime/data/msvd/youtube2text_iccv15/dict_youtube_mapping.pkl'
        msvd_feature_path = '/home/sensetime/data/msvd/npy2'
        max_words = 30
        msvd = MSVD(msvd_csv_path, msvd_video_name2id_map, msvd_feature_path, max_words)
        self._save_cache(msvd, msvd_file)
        return msvd

    def msr_vtt(self):
        """Return the MSR-VTT dataset object, using './msrvtt.pkl' as a cache.

        Returns:
            The object unpickled from './msrvtt.pkl' if that file exists,
            otherwise a newly constructed ``MSRVTT`` instance (which is also
            written to './msrvtt.pkl' before being returned).
        """
        msrvtt_file = './msrvtt.pkl'
        if os.path.isfile(msrvtt_file):
            return self._load_cache(msrvtt_file)

        msr_vtt_json_path = '/home/sensetime/data/msr-vtt/videodatainfo_2017.json'
        msr_vtt_feature_path = '/home/sensetime/data/msr-vtt/npy'
        max_words = 30
        msrvtt = MSRVTT(msr_vtt_json_path, msr_vtt_feature_path, max_words)
        self._save_cache(msrvtt, msrvtt_file)
        return msrvtt
if __name__ == '__main__':
    # Running this file as a script builds the MSVD dataset (or loads it from
    # its pickle cache); the returned object is not used here.
    DataEngine().msvd()