-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun.py
More file actions
128 lines (123 loc) · 4.19 KB
/
run.py
File metadata and controls
128 lines (123 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import numpy as np
import torch
from torch.autograd import Variable,grad
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from PIL import Image,ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import seaborn as sns
import glob
from torchvision.transforms import Compose,ToTensor,ToPILImage,Resize,Normalize,CenterCrop
from torch.utils.data import Dataset, DataLoader
from models import D_net,G_net
import visdom
# Visdom client for live training visualisation; requires a running
# `python -m visdom.server` (connects to localhost by default).
viz = visdom.Visdom()
# --- configuration ---------------------------------------------------------
data_dir ='Data/'                       # directory of training images
image_size=[3,128,128]                  # [C, H, W] of the full-size images
small_image_size=[3,57,57]              # NOTE(review): unused in this file
code = 'beta'                           # NOTE(review): unused in this file
use_GPU=torch.cuda.is_available()
batch_size=64
device=1 # GPU device
latent_dim = 128                        # generator noise dimensionality
# Eagerly load every file under Data/ into memory as a PIL image resized to
# (width, height) = (128, 128).  PIL.Image.resize accepts any 2-sequence,
# so image_size[1:] (a list) is fine here.
imgs_dir = glob.glob(data_dir+'*')
imgs = [Image.open(fil).resize(image_size[1:]) for fil in imgs_dir]
class PK_DATASET(Dataset):
    """Dataset over a list of pre-loaded PIL images.

    Each item is converted to a CHW float tensor in [0, 1] and then
    normalised channel-wise with mean/std 0.5, i.e. mapped to [-1, 1]
    (matching a tanh-output generator).
    """

    def __init__(self, imgs):
        self._images = imgs
        # ToTensor: PIL -> CHW float in [0,1]; Normalize: -> [-1,1]
        normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        self._transform = Compose([ToTensor(), normalize])

    def __len__(self):
        return len(self._images)

    def __getitem__(self, idx):
        image = self._images[idx]
        return self._transform(image)
def get_noise(batch_size=batch_size):
    """Return a (batch_size, latent_dim) tensor of uniform [0, 1) noise.

    The default batch size is the module-level ``batch_size`` (bound at
    definition time).  ``Variable`` is kept for old-PyTorch autograd
    compatibility, matching the rest of the file.
    """
    noise = torch.rand(batch_size, latent_dim)
    return Variable(noise)
Pk_dataset=PK_DATASET(imgs)
label='WGan'
# Visdom window showing the most recent generated sample (refreshed each
# epoch).  The ones-tensor is only a placeholder image.
img_graph = viz.image(torch.ones(image_size[1:]),
                      opts=dict(title=label+' generated img '+str(device)))
# Visdom line plot with one curve each for the critic loss, generator loss
# and the estimated Wasserstein distance; appended to once per epoch.
loss_graph=viz.line(
    Y=np.zeros((1,3)),
    opts=dict(
        fillarea=False,
        showlegend=True,
        legend=['D_loss','G_loss','W-distance'],
        width=400,
        height=400,
        xlabel='Iter',
        ylabel='Loss',
        # ytype='log',
        title=label+' - loss curve '+str(device),
    ))
# Number of critic (D) / generator (G) updates per batch, and total epochs.
d_iter = 1
g_iter = 1
epoch = 5000
Pk_dataloader=DataLoader(Pk_dataset,batch_size=batch_size,num_workers=1,shuffle=True)
def weights_init(m):
    """Standard DCGAN-style weight initialisation (Radford et al., 2015).

    Conv layers get weights ~ N(0, 0.02); BatchNorm layers get weights
    ~ N(1, 0.02) and zero biases.  Applied recursively via Module.apply.

    Fix: ``weights_init`` was called below but never defined or imported
    anywhere in the file (only D_net/G_net come from ``models``), so the
    script crashed with a NameError before training started.
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


d_model = D_net()
g_model = G_net(latent_dim)
d_model.apply(weights_init)
g_model.apply(weights_init)
# NOTE(review): WGAN-GP conventionally uses betas=(0.0, 0.9) or (0.5, 0.9);
# 0.01 is unusual but preserved — confirm against the original experiment.
d_optimizer = Adam(d_model.parameters(), lr=1e-4, betas=[0.01, .9])
g_optimizer = Adam(g_model.parameters(), lr=1e-4, betas=[0.01, .9])
if use_GPU:
    d_model.cuda(device)
    g_model.cuda(device)
# ---------------------------------------------------------------------------
# WGAN-GP training loop.
#
# Fixes: ``pause`` and ``loss_with_penalty`` were referenced but never
# defined or imported anywhere in the file, so the loop crashed with a
# NameError.  ``pause`` is the resume epoch (0 = from scratch) and
# ``loss_with_penalty`` is the standard WGAN-GP gradient penalty.
# ---------------------------------------------------------------------------
pause = 0          # epoch to resume from (0 = train from scratch)
lambda_gp = 10.0   # gradient-penalty weight (Gulrajani et al., 2017)


def loss_with_penalty(critic, real_data, fake_data):
    """WGAN-GP gradient penalty for ``critic``.

    Samples points uniformly on straight lines between real and fake
    batches and penalises the critic where the gradient norm at those
    points deviates from 1 (soft Lipschitz constraint).  ``real_data``
    and ``fake_data`` are detached tensors of identical shape.
    """
    b = real_data.size(0)
    eps = torch.rand(b, 1, 1, 1)
    if use_GPU:
        eps = eps.cuda(device)
    eps = eps.expand_as(real_data)
    interp = eps * real_data + (1.0 - eps) * fake_data
    interp = Variable(interp, requires_grad=True)
    critic_out = critic(interp)
    ones = torch.ones(critic_out.size())
    if use_GPU:
        ones = ones.cuda(device)
    gradients = grad(outputs=critic_out, inputs=interp,
                     grad_outputs=ones,
                     create_graph=True, retain_graph=True,
                     only_inputs=True)[0]
    gradients = gradients.view(b, -1)
    return lambda_gp * ((gradients.norm(2, dim=1) - 1.0) ** 2).mean()


for e in range(pause, epoch):
    for data in Pk_dataloader:
        # ---- critic (D) updates -------------------------------------
        for _ in range(d_iter):
            # Re-enable critic gradients (frozen during the G step).
            for p in d_model.parameters():
                p.requires_grad = True
            d_model.zero_grad()
            true_data = Variable(data)
            if use_GPU:
                true_data = true_data.cuda(device)
            d_true_score = d_model(true_data)
            true_loss = -d_true_score.mean()   # maximise score on real data
            # Fake batch matched to the (possibly short last) real batch.
            noise = get_noise(true_data.size()[0])
            if use_GPU:
                noise = noise.cuda(device)
            fake_data = g_model(noise)
            d_fake_score = d_model(fake_data)
            fake_loss = d_fake_score.mean()    # minimise score on fakes
            # .data detaches both batches so the penalty cannot backprop
            # into the generator.
            w_loss = loss_with_penalty(d_model, true_data.data, fake_data.data)
            loss = true_loss + fake_loss + w_loss
            loss.backward()
            d_optimizer.step()
        # ---- generator (G) updates ----------------------------------
        for _ in range(g_iter):
            g_model.zero_grad()
            # Freeze the critic so only G accumulates gradients.
            for p in d_model.parameters():
                p.requires_grad = False
            noise = get_noise()
            if use_GPU:
                noise = noise.cuda(device)
            fake_data = g_model(noise)
            g_score = d_model(fake_data)
            g_loss = -g_score.mean()
            g_loss.backward()
            g_optimizer.step()
    # ---- per-epoch logging ------------------------------------------
    dloss = float(true_loss) + float(fake_loss) + float(w_loss)
    gloss = float(g_loss)
    # Wasserstein-distance estimate: E[D(real)] - E[D(fake)].
    w_distance = -float(true_loss) - float(fake_loss)
    viz.line(Y=np.array((dloss, gloss, w_distance)).reshape(1, -1),
             X=np.array([[e, e, e]]),
             win=loss_graph,
             update='append')
    # Un-normalise from [-1, 1] back to [0, 1] for display/saving.
    fake_imgs = fake_data.cpu().data * 0.5 + 0.5
    # The last G step used the default noise batch (batch_size = 64),
    # so index 3 is always in range here.
    viz.image(fake_imgs[3], win=img_graph)
    if e % 10 == 0:
        img = ToPILImage()(fake_imgs[3])
        img.save('result/G_result/iter_%d.png' % e)
        torch.save(d_model, 'result/D_checkpoint/iter_d_%d.pt' % e)
        torch.save(g_model, 'result/G_checkpoint/iter_g_%d.pt' % e)