forked from eunsu-park/solar_euv_generation
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
227 lines (169 loc) · 9.05 KB
/
train.py
File metadata and controls
227 lines (169 loc) · 9.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import numpy as np
import os, glob
from imageio import imsave
from random import shuffle
from option import option_train
import keras.backend as K
import tensorflow as tf
from keras.optimizers import Adam
import time
os.environ['KERAS_BACKEND'] = 'tensorflow'
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
class train(option_train):
def __init__(self):
super(train, self).__init__()
def build_network(self):
from networks import PATCH_DISCRIMINATOR as network_D
self.network_D = network_D(self.isize, self.ch_input, self.ch_output, self.layer_max_d)
if self.name_model == 'UNET' :
from networks import UNET as network_G
elif self.name_model == 'SUNET' :
from networks import SUNET as network_G
elif self.name_model == 'SSUNET' :
from networks import SSUNET as network_G
elif self.name_model == 'CSSUNET' :
from networks import SUNET as network_G
self.network_G = network_G(self.isize, self.ch_input, self.ch_output)
def make_directory(self):
os.makedirs(self.root_ckpt, exist_ok=True)
os.makedirs(self.root_snap, exist_ok=True)
os.makedirs(self.root_test, exist_ok=True)
def make_dataset(self):
path_train_input = '%s/train/%s/%s'%(self.root_data, self.instr_input, self.name_input)
path_train_output = '%s/train/%s/%s'%(self.root_data, self.instr_output, self.name_output)
list_train_input = sorted(glob.glob('%s/*/*/*/*.npy'%path_train_input))
list_train_output = sorted(glob.glob('%s/*/*/*/*.npy'%path_train_output))
nb_train_input = len(list_train_input)
nb_train_output = len(list_train_output)
if nb_train_input != nb_train_output :
raise RuntimeError('# of train input(%d) and output(%d) are differnent'%(nb_train_input, nb_train_output))
self.list_train = list(zip(list_train_input, list_train_output))
self.nb_train = nb_train_input
path_validation_input = '%s/validation/%s/%s'%(self.root_data, self.instr_input, self.name_input)
list_validation_input = sorted(glob.glob('%s/*/*/*/*.npy'%path_validation_input))
nb_validation_input = len(list_validation_input)
self.list_validation = list_validation_input
self.nb_validation = nb_validation_input
path_test_input = '%s/test/%s/%s'%(self.root_data, self.instr_input, self.name_input)
list_test_input = sorted(glob.glob('%s/*/*/*/*.npy'%path_test_input))
nb_test_input = len(list_test_input)
self.list_test = list_test_input
self.nb_test = nb_test_input
def train_batch_generator(self):
data_AB = self.list_train
length = self.nb_train
epoch = i = 0
tmpsz = None
while True :
sz = tmpsz if tmpsz else self.bsize
if i + sz > length :
shuffle(data_AB)
i = 0
epoch += 1
batch_A = np.array([])
batch_B = np.array([])
for j in range(i, i+sz):
data_A = self.make_tensor_input(data_AB[j][0])
data_B = self.make_tensor_output(data_AB[j][1])
batch_A = np.append(batch_A, data_A)
batch_B = np.append(batch_B, data_B)
batch_A.shape = (self.bsize, self.isize, self.isize, self.ch_input)
batch_B.shape = (self.bsize, self.isize, self.isize, self.ch_output)
if self.shake :
batch_A, batch_B = self.shake_tensor(batch_A, batch_B)
i += sz
tmpsz = yield epoch, batch_A.astype(np.float32), batch_B.astype(np.float32)
## generate AIA-like image from HMI data in validation data set
def run_validation(self):
path_snap = '%s/iter_%07d'%(self.root_snap, self.iter_gen)
os.makedirs(path_snap, exist_ok=True)
for file_A in self.list_validation :
real_A = self.make_tensor_input(file_A)
fake_B = self.network_G.predict(real_A)
fake_B, fake_B_png = self.make_output(fake_B)
name_save = '%s.%s'%(self.mode, file_A.split('/')[-1][-23:-4])
np.save('%s/%s.npy'%(path_snap, name_save), fake_B)
imsave('%s/%s.png'%(path_snap, name_save), fake_B_png)
print('Validation snaps (%d images) are saved in %s'%(self.nb_validation, path_snap))
## generate AIA-like image from HMI data in test data set
def run_test(self):
path_test = '%s/iter_%07d'%(self.root_test, self.iter_gen)
os.makedirs(path_test, exist_ok=True)
for file_A in self.list_test :
real_A = self.make_tensor_input(file_A)
fake_B = self.network_G.predict(real_A)
fake_B, fake_B_png = self.make_output(fake_B)
name_save = '%s.%s'%(self.mode, file_A.split('/')[-1][-23:-4])
np.save('%s/%s.npy'%(path_test, name_save), fake_B)
imsave('%s/%s.png'%(path_test, name_save), fake_B_png)
print('Test results (%d images) are saved in %s'%(self.nb_test, path_test))
def run(self):
self.make_directory()
self.make_dataset()
train_batch_generator = self.train_batch_generator()
os.environ['CUDA_VISIBLE_DEVICES'] = self.gpu_id
K.set_image_data_format('channels_last')
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
tf.Session(config = config)
self.build_network()
real_A = self.network_G.input
fake_B = self.network_G.output
real_B = self.network_D.inputs[1]
output_D_real = self.network_D([real_A, real_B])
output_D_fake = self.network_D([real_A, fake_B])
loss_FN = lambda output, target : -K.mean(K.log(output+1e-12)*target+K.log(1-output+1e-12)*(1-target))
loss_D_real = loss_FN(output_D_real, K.ones_like(output_D_real))
loss_D_fake = loss_FN(output_D_fake, K.zeros_like(output_D_fake))
loss_G_fake = loss_FN(output_D_fake, K.ones_like(output_D_fake))
loss_L = K.mean(K.abs(fake_B-real_B))
loss_D = loss_D_real + loss_D_fake
training_updates_D = Adam(lr = 2e-4, beta_1 = 0.5).get_updates(self.network_D.trainable_weights, [], loss_D)
network_D_train = K.function([real_A, real_B], [loss_D/2.0], training_updates_D)
loss_G = loss_G_fake + 100 * loss_L
training_updates_G = Adam(lr = 2e-4, beta_1 = 0.5).get_updates(self.network_G.trainable_weights, [], loss_G)
network_G_train = K.function([real_A, real_B], [loss_G_fake, loss_L], training_updates_G)
t0 = time.time()
t1 = time.time()
self.iter_gen, epoch = 0, 0
err_L, err_G, err_D = 0, 0, 0
err_L_sum, err_G_sum, err_D_sum = 0, 0, 0
print('\n--------------------------------\n')
print('\nNow start below session!\n')
print('Mode: %s'%self.mode)
print('Checkpoint save path: %s'%(self.root_ckpt))
print('Validation snap save path: %s'%(self.root_snap))
print('Test result save path: %s'%(self.root_test))
print('# of train, validation, and test datasets : %d, %d, %d'%(self.nb_train, self.nb_validation, self.nb_test))
print('\n--------------------------------\n')
while self.iter_gen <= self.iter_max :
epoch, train_A, train_B = next(train_batch_generator)
err_G, err_L = network_G_train([train_A, train_B])
err_D, = network_D_train([train_A, train_B])
err_D_sum += err_D
err_G_sum += err_G
err_L_sum += err_L
self.iter_gen += self.bsize
if self.iter_gen % self.iter_display == 0:
err_D_mean = err_D_sum/self.iter_display
err_G_mean = err_G_sum/self.iter_display
err_L_mean = err_L_sum/self.iter_display
print('[%d][%d/%d] LOSS_D: %5.3f LOSS_G: %5.3f LOSS_L: %5.3f T: %dsec/%dits, Total T: %d'
% (epoch, self.iter_gen, self.iter_max,
err_D_mean, err_G_mean, err_L_mean,
time.time()-t1, self.iter_display, time.time()-t0))
err_L_sum, err_G_sum, err_D_sum = 0, 0, 0
t1 = time.time()
if self.iter_gen % self.iter_save == 0:
dst_model_G = '%s/%s.iter.%07d.G.h5'%(self.root_ckpt, self.mode, self.iter_gen)
dst_model_D = '%s/%s.iter.%07d.D.h5'%(self.root_ckpt, self.mode, self.iter_gen)
self.network_G.save(dst_model_G)
self.network_D.save(dst_model_D)
print('network_G and network_D are saved under %s'%(self.root_ckpt))
self.run_validation()
self.run_test()
t1 = time.time()
# run train
if __name__ == '__main__' :
do_train = train()
do_train.run()