-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata.py
More file actions
96 lines (87 loc) · 4.29 KB
/
data.py
File metadata and controls
96 lines (87 loc) · 4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import numpy as np
import torchvision.datasets as datasets
from torch.utils import data
from torchvision import transforms
from PIL import Image
class ImageFolder(data.Dataset):
    """Custom Dataset compatible with prebuilt DataLoader.

    Treats every entry directly under ``root`` as an image file; no
    recursion into subdirectories is performed.
    """

    def __init__(self, root, transform=None):
        """Initialize image paths and preprocessing module.

        Args:
            root: Directory whose entries are the image files.
            transform: Optional callable applied to each loaded PIL image.
        """
        # Sort the listing: os.listdir returns entries in arbitrary,
        # filesystem-dependent order, which would make sample indices
        # irreproducible across runs and machines.
        self.image_paths = [os.path.join(root, name) for name in sorted(os.listdir(root))]
        self.transform = transform

    def __len__(self):
        """Return the total number of image files."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Read the image at ``index``, preprocess it, and return it."""
        image_path = self.image_paths[index]
        # Force RGB so grayscale/RGBA files yield a consistent 3 channels.
        image = Image.open(image_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image
def get_loader(train_valid_split=None, num_workers=2, **setting):
    """Create and return train/valid DataLoaders.

    Args:
        train_valid_split: Fraction of the dataset assigned to training when
            the data source has no predefined split (CelebA/AFHQ). ``None``
            means "no validation loader" for those datasets. Ignored for
            MNIST/CIFAR10, which ship with their own train/test split.
        num_workers: Number of DataLoader worker processes.
        **setting: Must contain ``config`` (a mapping with 'Data',
            'Resolution' and 'Batch size' keys) and ``path`` (a mapping with
            'data_root').

    Returns:
        dict: ``{'train': DataLoader, 'valid': DataLoader or None}``.

    Raises:
        ValueError: If ``config['Data']`` names an unsupported dataset.
    """
    config = setting['config']
    image_path = setting['path']['data_root']
    image_size = (config['Resolution'], config['Resolution'])
    transform = basic_transform(image_size)
    batch_size = config['Batch size']

    if config['Data'] in ('MNIST', 'CIFAR10'):
        # Torchvision datasets come with a canonical train/test split.
        dataset_cls = getattr(datasets, config['Data'])
        train_dataset = dataset_cls(root=image_path, train=True,
                                    download=True, transform=transform)
        valid_dataset = dataset_cls(root=image_path, train=False,
                                    download=True, transform=transform)
        train_loader = data.DataLoader(dataset=train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=num_workers)
        valid_loader = data.DataLoader(dataset=valid_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=num_workers)
    elif config['Data'] in ('CelebA', 'AFHQ'):
        dataset = ImageFolder(image_path, transform)
        if train_valid_split is None:
            # No split requested: everything goes into the training loader.
            # (The original also tested `data_loader is None` here, but that
            # was always true at this point — dead condition removed.)
            train_loader = data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=num_workers)
            valid_loader = None
        else:
            train_size = int(train_valid_split * len(dataset))
            valid_size = len(dataset) - train_size
            train_dataset, valid_dataset = data.random_split(
                dataset, [train_size, valid_size])
            train_loader = data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=num_workers)
            valid_loader = data.DataLoader(dataset=valid_dataset,
                                           batch_size=batch_size,
                                           shuffle=False,
                                           num_workers=num_workers)
    else:
        raise ValueError(f"Dataset {config['Data']} is not supported.")
    return {'train': train_loader, 'valid': valid_loader}
def basic_transform(image_size):
    """Build the forward preprocessing pipeline.

    Resizes and center-crops to ``image_size``, converts to a tensor,
    and rescales pixel values from [0, 1] to [-1, 1].
    """
    steps = [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # Map [0, 1] -> [-1, 1].
        transforms.Lambda(lambda t: (t * 2) - 1),
    ]
    return transforms.Compose(steps)
def reverse_transform():
    """Build the inverse of ``basic_transform``'s normalization.

    Maps tensor values from [-1, 1] back to [0, 255] and converts the
    result to a detached NumPy array. Note: the output stays in CHW
    layout (no permute to HWC is applied).
    """
    steps = [
        # [-1, 1] -> [0, 1]
        transforms.Lambda(lambda t: (t + 1) / 2),
        # [0, 1] -> [0, 255]
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.detach().numpy()),
    ]
    return transforms.Compose(steps)