-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_loader.py
More file actions
69 lines (59 loc) · 2.75 KB
/
data_loader.py
File metadata and controls
69 lines (59 loc) · 2.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from torchvision import datasets, transforms
import torch
def load_data(data_folder, batch_size, train, num_workers=0, **kwargs):
    """Create an ``ImageFolder`` dataset and a loader over it.

    Args:
        data_folder: Root directory handed to ``datasets.ImageFolder``.
        batch_size: Batch size forwarded to ``get_data_loader``.
        train: Truthy selects the training pipeline (resize to 256x256,
            random 224 crop, random horizontal flip); falsy selects the
            evaluation pipeline (plain resize to 224x224). Both pipelines
            finish with ``ToTensor`` and the same ``Normalize`` step.
        num_workers: Forwarded to ``get_data_loader``.
        **kwargs: Extra keyword arguments forwarded to ``get_data_loader``.

    Returns:
        A ``(data_loader, n_class)`` tuple, where ``n_class`` is the number
        of entries in the dataset's ``classes`` attribute.
    """
    is_train = bool(train)

    # The tensor conversion is followed by the same normalization in both modes.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if is_train:
        steps = [
            transforms.Resize([256, 256]),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    else:
        steps = [
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
            normalize,
        ]

    image_folder = datasets.ImageFolder(
        root=data_folder,
        transform=transforms.Compose(steps),
    )

    # Shuffling and dropping the last partial batch are requested only in
    # training mode.
    loader = get_data_loader(
        image_folder,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        drop_last=is_train,
        **kwargs,
    )
    return loader, len(image_folder.classes)
def get_data_loader(dataset, batch_size, shuffle=True, drop_last=False, num_workers=0, infinite_data_loader=False, **kwargs):
    """Build either a regular ``DataLoader`` or an ``InfiniteDataLoader``.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader should shuffle. This value is forwarded
            as given; it used to be overridden with ``True``, which shuffled
            loaders that had explicitly asked for ``shuffle=False``.
        drop_last: Whether to drop the final incomplete batch.
        num_workers: Number of worker processes used for loading.
        infinite_data_loader: When truthy, return an ``InfiniteDataLoader``
            instead of a ``torch.utils.data.DataLoader``.
        **kwargs: Extra keyword arguments forwarded to the chosen loader.

    Returns:
        The constructed loader object.
    """
    if infinite_data_loader:
        return InfiniteDataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
            num_workers=num_workers,
            **kwargs,
        )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        **kwargs,
    )
class _InfiniteSampler(torch.utils.data.Sampler):
    """Sampler wrapper that restarts the wrapped sampler whenever it runs out."""

    def __init__(self, sampler):
        # The wrapped sampler is re-iterated from scratch on every pass.
        self.sampler = sampler

    def __iter__(self):
        # Each pass calls iter() on the wrapped sampler again, so nothing from
        # an earlier pass is cached or replayed.
        # NOTE: if the wrapped sampler yields nothing, this loop never yields.
        while True:
            yield from self.sampler
class InfiniteDataLoader:
    """Endless stream of batches drawn from a dataset.

    A ``torch.utils.data.DataLoader`` is driven by a batch sampler wrapped in
    ``_InfiniteSampler``, so iteration restarts the sampler instead of
    stopping at the end of a pass.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.
        shuffle: When truthy (and ``weights`` is ``None``) indices come from a
            ``RandomSampler``; when falsy a ``SequentialSampler`` is used.
            Previously this argument was accepted but ignored.
        drop_last: Forwarded to ``BatchSampler``.
        num_workers: Forwarded to the underlying ``DataLoader``.
        weights: Optional per-sample weights. When given, a
            ``WeightedRandomSampler`` is used and ``shuffle`` has no effect.
        **kwargs: Extra keyword arguments forwarded to the underlying
            ``DataLoader`` (previously they were silently dropped).
    """

    def __init__(self, dataset, batch_size, shuffle=True, drop_last=False, num_workers=0, weights=None, **kwargs):
        if weights is not None:
            # Draws ``batch_size`` indices per pass, without replacement.
            sampler = torch.utils.data.WeightedRandomSampler(
                weights,
                replacement=False,
                num_samples=batch_size)
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset,
                                                     replacement=False)
        else:
            # Honor an explicit shuffle=False with in-order indices.
            sampler = torch.utils.data.SequentialSampler(dataset)

        batch_sampler = torch.utils.data.BatchSampler(
            sampler,
            batch_size=batch_size,
            drop_last=drop_last)

        # A single iterator is created up front and shared by every
        # ``__iter__`` call, so repeated iteration continues the same stream.
        self._infinite_iterator = iter(torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(batch_sampler),
            **kwargs
        ))

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        # The stream has no finite length; 0 is returned unconditionally.
        # NOTE(review): presumably callers treat 0 as "no epoch length" —
        # verify. Be aware this also makes ``bool(loader)`` evaluate to False.
        return 0