-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_utils.py
More file actions
123 lines (106 loc) · 3.43 KB
/
data_utils.py
File metadata and controls
123 lines (106 loc) · 3.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from operator import itemgetter
import random
import numpy as np
import torch
from torch.autograd import Variable
class TaskBatch(object):
    """A sized wrapper around an iterable of meta-learning tasks.

    The underlying iterable is typically a generator, whose length cannot
    be queried directly, so the size is carried alongside it explicitly.
    """
    def __init__(self, task_batch, size):
        # task_batch: iterable of task descriptions (usually a generator)
        # size: number of tasks the iterable will yield
        self.task_batch = task_batch
        self.size = size
    def __len__(self):
        return self.size
    def __iter__(self):
        # BUG FIX: the original returned self.task_batch directly, which
        # violates the iterator protocol when a non-iterator iterable
        # (e.g. a plain list) is passed; iter() handles both cases and is
        # a no-op for generators.
        return iter(self.task_batch)
def sample_ids(X, size):
    """Draw `size` random indices into X (uniform, with replacement)."""
    population = len(X)
    return np.random.choice(population, size)
def get_sample_from_ids(X, Y, ids):
    """Select the rows `ids` from X and Y and wrap them as non-grad Variables."""
    inputs = Variable(X[ids].contiguous(), requires_grad=False)
    targets = Variable(Y[ids].contiguous(), requires_grad=False)
    return inputs, targets
def split_array(array, splits=(-1,)):
    """Randomly partition `array` into len(splits) pieces.

    Each entry of `splits` is either a fraction in (0, 1) of the array
    length, an absolute element count (>= 1), or a negative sentinel
    meaning "everything not yet assigned".  A negative entry terminates
    the partitioning early: any entries after it are ignored.

    numpy arrays are sliced with fancy indexing (preserving their type);
    any other sequence yields a tuple of its selected elements.
    """
    array_len = len(array)
    inds = list(range(array_len))
    random.shuffle(inds)
    rtn = []
    ctr = 0
    def get_split(arr, idx):
        # select the elements of arr at positions idx
        if isinstance(arr, np.ndarray):
            return arr[idx]
        # BUG FIX: the original used itemgetter(*idx), which returns a
        # bare element (not a 1-tuple) when len(idx) == 1 and raises
        # TypeError when idx is empty; build the tuple explicitly.
        return tuple(arr[i] for i in idx)
    for split in splits:
        if split < 0:
            # remainder split: take everything left over and stop
            rtn.append(get_split(array, inds[ctr:]))
            return rtn
        if split < 1.:
            # fractional split: convert to an element count
            split = int(array_len * split)
        rtn.append(get_split(array, inds[ctr:ctr + split]))
        ctr += split
    return rtn
class KShotData(object):
    """Per-class example store for k-shot sampling.

    `classes_data` is a sequence whose i-th entry holds every example of
    class i and supports indexing by a list of positions (e.g. a torch
    tensor or numpy array).
    """
    def __init__(self, classes_data):
        self.classes_data = classes_data
        self.num_classes = len(classes_data)
        # candidate example positions for each class
        self.class_idx = [range(len(examples)) for examples in classes_data]
    def get_class(self, cls, n=1):
        """Return `n` distinct randomly-chosen examples of class `cls`."""
        picked = random.sample(self.class_idx[cls], n)
        return self.classes_data[cls][picked]
    def split(self, splits=[.8, .1, .1]):
        """Partition the per-class data and wrap each piece as KShotData."""
        pieces = split_array(self.classes_data, splits)
        return [KShotData(piece) for piece in pieces]
class KShotLoader:
    """Infinite sampler of n-way, k-shot meta-learning task batches.

    Wraps a KShotData-like object exposing `num_classes` and
    `get_class(cls, k)`.  The `train` / `val` / `test` properties each
    return an infinite generator of TaskBatch objects drawn from the
    matching split of the data (falling back to the data object itself
    when that split attribute is absent).
    """
    def __init__(self, kshotdata, n, k, metabatch_size, transform=None):
        # n: number of classes per task; k: finetune examples per class
        # transform: optional callable applied to each sampled input batch
        self.data = kshotdata
        self.n = n
        self.k = k
        self.metabatch_size = metabatch_size
        self.transform = transform
    def _stream(self, split_name):
        """Yield TaskBatch objects forever from the named data split."""
        data = getattr(self.data, split_name, self.data)
        while True:
            yield TaskBatch(self.sample_tasks(data), self.metabatch_size)
    @property
    def train(self):
        return self._stream('train')
    @property
    def val(self):
        return self._stream('val')
    @property
    def test(self):
        return self._stream('test')
    def sample_tasks(self, data):
        """Yield task descriptions in the format [task_id, finetune data, eval data]."""
        for _ in range(self.metabatch_size):
            tr_data, tr_targets, sampled_classes, class_map = self.sample_task_data(data)
            # the eval set reuses the same classes and label mapping,
            # with one fresh example per class
            yield 0, [(tr_data, tr_targets)], [self.sample_task_data(data, 1, sampled_classes, class_map)[:2]]
    def sample_task_data(self, data, k=None, sampled_classes=None, class_map=None):
        """Sample one n-way, k-shot batch.

        Returns (inputs, targets, sampled_classes, class_map), where
        targets are the sampled class ids remapped to 0..n-1 via class_map.
        Pass `sampled_classes`/`class_map` back in to resample from the
        same task.
        """
        if k is None:
            k = self.k
        if sampled_classes is None:
            # BUG FIX: the original used Python 2's xrange, a NameError
            # on Python 3
            sampled_classes = classes = random.sample(range(data.num_classes), self.n)
        else:
            classes = sampled_classes
        if class_map is None:
            class_map = {c: i for i, c in enumerate(sampled_classes)}
        sampled_class_data = torch.cat([data.get_class(c, k) for c in classes])
        if self.transform is not None:
            sampled_class_data = self.transform(sampled_class_data)
        if k > 1:
            # repeat each class label k times to line up with the
            # k stacked examples per class
            classes = [c for c in classes for _ in range(k)]
        sampled_class_data = Variable(sampled_class_data, requires_grad=True)
        # BUG FIX: the original passed requires_grad=True here, but integer
        # tensors cannot require grad (raises RuntimeError on modern PyTorch)
        targets = Variable(torch.LongTensor([class_map[c] for c in classes]), requires_grad=False)
        if sampled_class_data.is_cuda:
            targets = targets.cuda()
        return sampled_class_data, targets, sampled_classes, class_map