-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·95 lines (83 loc) · 3.39 KB
/
train.py
File metadata and controls
executable file
·95 lines (83 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils import data
from torchvision import datasets, transforms, utils
from models import LabelNet, PixelCNN
# Hyper-parameters for conditional PixelCNN training on MNIST.
n_classes = 10   # number of classes (MNIST digits 0-9)
n_epochs = 25    # number of epochs to train
n_layers = 7     # number of masked convolutional layers in the PixelCNN
n_channels = 16  # number of feature channels per layer
# Fall back to CPU when no CUDA device is present so the script still runs
# (the original hard-coded 'cuda:0' crashes on CPU-only machines).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
def to_one_hot(y, k=10):
    """Return a float32 one-hot matrix of shape (y.numel(), k).

    Each row selects the row of the k-by-k identity matrix indexed by
    the corresponding integer label in *y* (any shape; flattened first).
    """
    flat_labels = y.view(-1)
    # Row i of eye(k) is exactly the one-hot encoding of label i.
    return torch.eye(k, dtype=torch.float32)[flat_labels]
# Build the conditional PixelCNN and the label-embedding network
# (both defined in the project-local models.py) on the chosen device.
pixel_cnn = PixelCNN(n_channels, n_layers).to(device)
label_net = LabelNet().to(device)
# MNIST loaders; ToTensor scales pixel values into [0, 1].
trainloader = data.DataLoader(datasets.MNIST('data', train=True,
download=True,
transform=transforms.ToTensor()),
batch_size=128, shuffle=True,
num_workers=1, pin_memory=True)
testloader = data.DataLoader(datasets.MNIST('data', train=False,
download=True,
transform=transforms.ToTensor()),
batch_size=128, shuffle=False,
num_workers=1, pin_memory=True)
# Reusable buffer for 120 generated 28x28 images (12 rows of digits 0-9).
# NOTE(review): torch.Tensor(...) is uninitialized memory; it is zeroed with
# fill_(0) before every sampling pass, so this is safe as used below.
sample = torch.Tensor(120, 1, 28, 28).to(device)
# One optimizer jointly updates both networks.
optimizer = optim.Adam(list(pixel_cnn.parameters())+
list(label_net.parameters()))
# Per-pixel 256-way classification over intensity levels 0..255.
criterion = torch.nn.CrossEntropyLoss()
# Training loop adapted from jzbontar/pixelcnn-pytorch.
for epoch in range(n_epochs):
    # --- train one epoch ---
    err_tr = []
    time_tr = time.time()
    pixel_cnn.train()
    label_net.train()
    for inp, lab in trainloader:
        # Embed the class labels to condition the PixelCNN.
        lab_emb = label_net(to_one_hot(lab).to(device))
        inp = inp.to(device)
        # Targets are the integer pixel intensities 0..255 of the input image
        # (inputs are in [0, 1] from ToTensor, so multiply back by 255).
        target = (inp[:, 0] * 255).long()
        loss = criterion(pixel_cnn(inp, lab_emb), target)
        err_tr.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    time_tr = time.time() - time_tr

    with torch.no_grad():
        # --- compute error on the test set ---
        err_te = []
        time_te = time.time()
        pixel_cnn.eval()
        label_net.eval()
        for inp, lab in testloader:
            lab_emb = label_net(to_one_hot(lab).to(device))
            inp = inp.to(device)
            target = (inp[:, 0] * 255).long()
            loss = criterion(pixel_cnn(inp, lab_emb), target)
            err_te.append(loss.item())
        time_te = time.time() - time_te

        # --- sample: 12 rows of the digits 0..9 (120 images) ---
        labels = torch.arange(10).repeat(12, 1).flatten()
        # The label conditioning does not depend on the pixel position, so
        # compute the embedding once instead of re-running label_net inside
        # the 28x28 pixel loop (784 redundant forward passes per epoch).
        lab_emb = label_net(to_one_hot(labels).to(device))
        sample.fill_(0)
        # Autoregressive generation: sample each pixel in raster order,
        # conditioning on all previously generated pixels.
        for i in range(28):
            for j in range(28):
                out = pixel_cnn(sample, lab_emb)
                probs = F.softmax(out[:, :, i, j], dim=1)
                sample[:, :, i, j] = torch.multinomial(probs, 1).float() / 255.
    utils.save_image(sample, 'sample_{:02d}.png'.format(epoch + 1), nrow=10, padding=0)

    # Cross-entropy is in nats; dividing by log(2) converts to bits per pixel.
    output_string = 'epoch: {}/{} bpp (train): {:.7f}' + \
        ' bpp (test): {:.7f} time (training): {:.1f}s time (testing): {:.1f}s'
    print(output_string.format(epoch + 1,
                               n_epochs,
                               np.mean(err_tr) / np.log(2),
                               np.mean(err_te) / np.log(2),
                               time_tr,
                               time_te))