-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathbalanced_batch_sampler.py
More file actions
51 lines (45 loc) · 2.36 KB
/
balanced_batch_sampler.py
File metadata and controls
51 lines (45 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, BatchSampler
class BalancedBatchSampler(BatchSampler):
    """Batch sampler that yields class-balanced batches of dataset indices.

    Each batch is built by choosing ``n_classes`` distinct labels uniformly at
    random (without replacement) and taking ``n_samples`` indices of each
    chosen label, so a batch normally holds ``n_classes * n_samples`` indices.

    The index list of every label is shuffled once at construction and then
    consumed sequentially. When a label no longer has ``n_samples`` unused
    indices left, its list is reshuffled and consumption restarts from the
    beginning. These per-label cursors are stored on the instance, so they
    carry over from one epoch (one ``iter()``) to the next.

    Note: a label with fewer than ``n_samples`` datapoints in total contributes
    all of its indices, so batches that contain it are smaller than
    ``batch_size``.

    Args:
        labels: one label per dataset item (a label may also be a unique id
            per datapoint). Either a ``torch.Tensor`` on any device, or
            anything ``np.asarray`` accepts (list, numpy array).
        n_classes: number of distinct labels drawn per batch.
        n_samples: number of indices drawn per chosen label.

    Raises:
        ValueError: if ``n_classes`` or ``n_samples`` is not positive. Also
            raised when iteration starts if at least one batch would be
            produced but ``n_classes`` exceeds the number of distinct labels.
    """

    def __init__(self, labels, n_classes, n_samples):
        if n_classes <= 0 or n_samples <= 0:
            # A zero batch size would make __iter__ loop forever and
            # __len__ divide by zero.
            raise ValueError(
                f"n_classes and n_samples must be positive, "
                f"got n_classes={n_classes}, n_samples={n_samples}"
            )
        self.labels = labels
        # Convert to numpy once. Tensors on any device are supported by
        # detaching them and moving them to the CPU first.
        if isinstance(labels, torch.Tensor):
            labels_np = labels.detach().cpu().numpy()
        else:
            labels_np = np.asarray(labels)
        self.labels_set = list(set(labels_np))
        self.label_to_indices = {
            label: np.where(labels_np == label)[0] for label in self.labels_set
        }
        for label in self.labels_set:
            np.random.shuffle(self.label_to_indices[label])
        # Read cursor into each label's (shuffled) index list.
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.n_dataset = len(labels_np)
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        """Yield ``len(self)`` batches, each a list of int dataset indices."""
        if self.n_dataset >= self.batch_size and self.n_classes > len(self.labels_set):
            raise ValueError(
                f"n_classes={self.n_classes} exceeds the number of distinct "
                f"labels ({len(self.labels_set)})"
            )
        self.count = 0
        # Use ``<=`` (not ``<``) so that exactly n_dataset // batch_size
        # batches are yielded, matching __len__. This includes the last
        # full batch when n_dataset is a multiple of batch_size.
        while self.count + self.batch_size <= self.n_dataset:
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                pool = self.label_to_indices[class_]
                start = self.used_label_indices_count[class_]
                # .tolist() yields plain Python ints rather than numpy scalars.
                indices.extend(pool[start:start + self.n_samples].tolist())
                self.used_label_indices_count[class_] += self.n_samples
                # Not enough unused indices left for another full draw:
                # reshuffle this label's indices and restart its cursor.
                if self.used_label_indices_count[class_] + self.n_samples > len(pool):
                    np.random.shuffle(pool)
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        """Number of batches per epoch: ``n_dataset // batch_size``."""
        return self.n_dataset // self.batch_size
# train_labels = torch.tensor([item[2] for item in train_dataset])
# train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler)
# test_labels = torch.tensor([item[2] for item in test_dataset])
# train_sampler = BalancedBatchSampler(train_labels, BATCH_SIZE, 1)
# test_sampler = BalancedBatchSampler(test_labels, BATCH_SIZE, 1)
# test_dataloader = DataLoader(test_dataset, batch_sampler=test_sampler)