forked from Naresh1318/Video_Frame_Interpolation
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathimg2img.py
More file actions
186 lines (151 loc) · 12 KB
/
img2img.py
File metadata and controls
186 lines (151 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import numpy as np
import tensorflow as tf
import mission_control as mc
import ops
import utils
import sys
# ---- Graph inputs and dataset --------------------------------------------
# Frames are fed as float32 batches of 288x352 RGB images.
input_image = tf.placeholder(dtype=tf.float32, shape=[None, 288, 352, 3], name='Input_image')
target_image = tf.placeholder(dtype=tf.float32, shape=[None, 288, 352, 3], name='Target_image')
# Fed the running batch counter by train(); drives the exponential
# learning-rate decay schedules.
global_step = tf.placeholder(dtype=tf.int64, shape=[], name="Global_Step")
# Frame pairs extracted from the video at mc.video_path.
# NOTE(review): test_data/test_target/mean_img are never used in this file —
# presumably mean_img is a dataset mean for normalisation; verify against
# utils.generate_dataset_from_video.
train_data, train_target, test_data, test_target, mean_img = utils.generate_dataset_from_video(mc.video_path)
def generator(x, reuse=False):
    """U-Net style generator: 8-layer strided-conv encoder, 8-layer
    transposed-conv decoder with skip connections (pix2pix-like).

    x: [batch, 288, 352, 3] input frame; returns a tanh-activated image
    tensor of the same spatial size.
    reuse: when True, reuse variables of a previously built generator.

    Spatial sizes noted below assume SAME padding in ops.cnn_2d /
    ops.cnn_2d_trans — TODO confirm. NOTE(review): the decoder hard-codes
    mc.batch_size in every output_shape, so despite the None batch dim of the
    placeholder this graph only runs with that exact batch size.
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    # Encoder: eight stride-2 4x4 convolutions, each halving height/width.
    # 288x352x3 -> 144x176x64; the first layer has no batch norm.
    conv_1 = ops.lrelu(ops.cnn_2d(x, weight_shape=[4, 4, 3, 64], strides=[1, 2, 2, 1], name='g_e_conv_1'))
    # 144x176x64 -> 72x88x128
    conv_2 = ops.lrelu(ops.batch_norm(ops.cnn_2d(conv_1, weight_shape=[4, 4, 64, 128],
                                                 strides=[1, 2, 2, 1], name='g_e_conv_2'),
                                      center=True, scale=True, is_training=True, scope='g_e_batch_Norm_2'))
    # 72x88x128 -> 36x44x256
    conv_3 = ops.lrelu(ops.batch_norm(ops.cnn_2d(conv_2, weight_shape=[4, 4, 128, 256],
                                                 strides=[1, 2, 2, 1], name='g_e_conv_3'),
                                      center=True, scale=True, is_training=True, scope='g_e_batch_Norm_3'))
    # 36x44x256 -> 18x22x512
    conv_4 = ops.lrelu(ops.batch_norm(ops.cnn_2d(conv_3, weight_shape=[4, 4, 256, 512],
                                                 strides=[1, 2, 2, 1], name='g_e_conv_4'),
                                      center=True, scale=True, is_training=True, scope='g_e_batch_Norm_4'))
    # 18x22x512 -> 9x11x512
    conv_5 = ops.lrelu(ops.batch_norm(ops.cnn_2d(conv_4, weight_shape=[4, 4, 512, 512],
                                                 strides=[1, 2, 2, 1], name='g_e_conv_5'),
                                      center=True, scale=True, is_training=True, scope='g_e_batch_Norm_5'))
    # 9x11x512 -> 5x6x512
    conv_6 = ops.lrelu(ops.batch_norm(ops.cnn_2d(conv_5, weight_shape=[4, 4, 512, 512],
                                                 strides=[1, 2, 2, 1], name='g_e_conv_6'),
                                      center=True, scale=True, is_training=True, scope='g_e_batch_Norm_6'))
    # 5x6x512 -> 3x3x512
    conv_7 = ops.lrelu(ops.batch_norm(ops.cnn_2d(conv_6, weight_shape=[4, 4, 512, 512],
                                                 strides=[1, 2, 2, 1], name='g_e_conv_7'),
                                      center=True, scale=True, is_training=True, scope='g_e_batch_Norm_7'))
    # 3x3x512 -> 2x2x512 (bottleneck)
    conv_8 = ops.lrelu(ops.batch_norm(ops.cnn_2d(conv_7, weight_shape=[4, 4, 512, 512],
                                                 strides=[1, 2, 2, 1], name='g_e_conv_8'),
                                      center=True, scale=True, is_training=True, scope='g_e_batch_Norm_8'))
    # Decoder: transposed convs with dropout on the first three layers and a
    # skip concat with the mirrored encoder layer after each. The +1 / *2-1
    # terms in output_shape are hand-computed so each upsampled map exactly
    # matches the spatial size of the skip tensor it is concatenated with
    # (sizes derive from the 288x352 input; see encoder comments above).
    # NOTE(review): weight_shape for ops.cnn_2d_trans appears to be
    # [k, k, out_channels, in_channels] — verify against ops.cnn_2d_trans.
    dconv_1 = ops.lrelu(tf.nn.dropout(ops.batch_norm(ops.cnn_2d_trans(conv_8, weight_shape=[2, 2, 512, 512], strides=[1, 2, 2, 1], output_shape=[mc.batch_size, conv_8.get_shape()[1].value+1, conv_8.get_shape()[2].value+1, 512], name='g_d_dconv_1'), center=True, scale=True, is_training=True, scope='g_d_batch_Norm_1'), keep_prob=0.5))
    dconv_1 = tf.concat([dconv_1, conv_7], axis=3)  # 3x3x(512+512)
    dconv_2 = ops.lrelu(tf.nn.dropout(ops.batch_norm(ops.cnn_2d_trans(dconv_1, weight_shape=[4, 4, 512, 1024], strides=[1, 2, 2, 1], output_shape=[mc.batch_size, dconv_1.get_shape()[1].value*2-1, dconv_1.get_shape()[2].value*2, 512], name='g_d_dconv_2'), center=True, scale=True, is_training=True, scope='g_d_batch_Norm_2'), keep_prob=0.5))
    dconv_2 = tf.concat([dconv_2, conv_6], axis=3)  # 5x6x(512+512)
    dconv_3 = ops.lrelu(tf.nn.dropout(ops.batch_norm(ops.cnn_2d_trans(dconv_2, weight_shape=[4, 4, 512, 1024], strides=[1, 2, 2, 1], output_shape=[mc.batch_size, dconv_2.get_shape()[1].value*2-1, dconv_2.get_shape()[2].value*2-1, 512], name='g_d_dconv_3'), center=True, scale=True, is_training=True, scope='g_d_batch_Norm_3'), keep_prob=0.5))
    dconv_3 = tf.concat([dconv_3, conv_5], axis=3)  # 9x11x(512+512)
    dconv_4 = ops.lrelu(ops.batch_norm(ops.cnn_2d_trans(dconv_3, weight_shape=[4, 4, 512, 1024], strides=[1, 2, 2, 1], output_shape=[mc.batch_size, dconv_3.get_shape()[1].value*2, dconv_3.get_shape()[2].value*2, 512], name='g_d_dconv_4'), center=True, scale=True, is_training=True, scope='g_d_batch_Norm_4'))
    dconv_4 = tf.concat([dconv_4, conv_4], axis=3)  # 18x22x(512+512)
    dconv_5 = ops.lrelu(ops.batch_norm(ops.cnn_2d_trans(dconv_4, weight_shape=[4, 4, 256, 1024], strides=[1, 2, 2, 1], output_shape=[mc.batch_size, dconv_4.get_shape()[1].value*2, dconv_4.get_shape()[2].value*2, 256], name='g_d_dconv_5'), center=True, scale=True, is_training=True, scope='g_d_batch_Norm_5'))
    dconv_5 = tf.concat([dconv_5, conv_3], axis=3)  # 36x44x(256+256)
    dconv_6 = ops.lrelu(ops.batch_norm(ops.cnn_2d_trans(dconv_5, weight_shape=[4, 4, 128, 512], strides=[1, 2, 2, 1], output_shape=[mc.batch_size, dconv_5.get_shape()[1].value*2, dconv_5.get_shape()[2].value*2, 128], name='g_d_dconv_6'), center=True, scale=True, is_training=True, scope='g_d_batch_Norm_6'))
    dconv_6 = tf.concat([dconv_6, conv_2], axis=3)  # 72x88x(128+128)
    dconv_7 = ops.lrelu(ops.batch_norm(ops.cnn_2d_trans(dconv_6, weight_shape=[4, 4, 64, 256], strides=[1, 2, 2, 1], output_shape=[mc.batch_size, dconv_6.get_shape()[1].value*2, dconv_6.get_shape()[2].value*2, 64], name='g_d_dconv_7'), center=True, scale=True, is_training=True, scope='g_d_batch_Norm_7'))
    dconv_7 = tf.concat([dconv_7, conv_1], axis=3)  # 144x176x(64+64)
    # Final layer: back to 288x352x3, tanh output (no batch norm / dropout).
    dconv_8 = tf.nn.tanh(ops.cnn_2d_trans(dconv_7, weight_shape=[4, 4, 3, 128], strides=[1, 2, 2, 1], output_shape=[mc.batch_size, dconv_7.get_shape()[1].value*2, dconv_7.get_shape()[2].value*2, 3], name='g_d_dconv_8'))
    return dconv_8
def discriminator(x, reuse=False):
    """Conditional discriminator: six stride-2 conv/batch-norm/leaky-ReLU
    stages followed by a dense layer producing unnormalised logits.

    x: channel-wise concatenation of the input frame with either the real
    target or the generated frame (6 channels in).
    reuse: when True, reuse variables of a previously built discriminator.
    """
    if reuse:
        tf.get_variable_scope().reuse_variables()
    # Channel depth at each stage; every stage halves height and width.
    depths = [6, 64, 128, 256, 512, 512, 512]
    features = x
    for layer in range(1, 7):
        conv = ops.cnn_2d(features,
                          weight_shape=[4, 4, depths[layer - 1], depths[layer]],
                          strides=[1, 2, 2, 1],
                          name='dis_conv_{}'.format(layer))
        normed = ops.batch_norm(conv, center=True, scale=True, is_training=True,
                                scope='dis_batch_Norm_{}'.format(layer))
        features = ops.lrelu(normed)
    # 5 * 6 is the final spatial size for 288x352 inputs; presumably
    # ops.dense flattens its input — verify against ops.dense.
    output = ops.dense(features, 5 * 6, 1, name='dis_output')
    return output
def train():
    """Build the conditional-GAN graph and run the training loop.

    Uses the module-level placeholders and dataset (input_image, target_image,
    global_step, train_data, train_target). Writes TensorBoard summaries to
    ./Tensorboard and checkpoints the model once per epoch.
    """
    # Generator forward pass; the second variable scope lets the fake branch
    # reuse the discriminator variables created by the real branch.
    with tf.variable_scope(tf.get_variable_scope()):
        generated_image = generator(input_image)

    # Condition the discriminator on the input frame (pix2pix-style) by
    # concatenating it with the real / generated target along channels.
    discriminator_real_input = tf.concat([input_image, target_image], axis=3)
    discriminator_fake_input = tf.concat([input_image, generated_image], axis=3)
    with tf.variable_scope(tf.get_variable_scope()):
        real_discriminator_op = discriminator(discriminator_real_input)
        fake_discriminator_op = discriminator(discriminator_fake_input, reuse=True)

    # Standard GAN cross-entropy losses on the discriminator logits.
    generator_fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(fake_discriminator_op), logits=fake_discriminator_op))
    discriminator_fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(fake_discriminator_op), logits=fake_discriminator_op))
    discriminator_real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(real_discriminator_op), logits=real_discriminator_op))

    # BUG FIX: the original computed tf.abs(generated - target + eps), which
    # shifts every pixel difference by eps; abs needs no epsilon (that trick
    # belongs to sqrt-based losses), so use the plain L1 distance.
    l1_loss = tf.reduce_mean(tf.abs(generated_image - target_image))
    discriminator_loss = discriminator_fake_loss + discriminator_real_loss
    # Generator objective: adversarial term plus weighted L1 reconstruction.
    generator_loss = mc.discriminator_weight * generator_fake_loss + mc.l1_weight * l1_loss

    # Split trainable variables by the name prefixes used in generator() and
    # discriminator() so each optimizer only updates its own network.
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'dis_' in var.name]
    g_vars = [var for var in t_vars if 'g_' in var.name]

    # NOTE(review): decay_steps=1 with a per-batch global_step multiplies the
    # learning rate by 0.96 every single batch — extremely aggressive decay;
    # confirm this is intended.
    g_learning_rate = tf.train.exponential_decay(mc.generator_lr, global_step,
                                                 1, 0.96, staircase=True)
    d_learning_rate = tf.train.exponential_decay(mc.discriminator_lr, global_step,
                                                 1, 0.96, staircase=True)
    generator_optimizer = tf.train.AdamOptimizer(g_learning_rate, beta1=mc.beta1).minimize(
        generator_loss, var_list=g_vars)
    discriminator_optimizer = tf.train.AdamOptimizer(d_learning_rate, beta1=mc.beta1).minimize(
        discriminator_loss, var_list=d_vars)

    # Summaries for TensorBoard.
    tf.summary.scalar('l1_loss', l1_loss)
    tf.summary.scalar('discriminator_loss', discriminator_loss)
    tf.summary.scalar('generator_fake_loss', generator_fake_loss)
    tf.summary.scalar('generator_loss', generator_loss)
    tf.summary.scalar('generator_lr', g_learning_rate)
    tf.summary.scalar('discriminator_lr', d_learning_rate)
    tf.summary.image('generated_image', generated_image)
    tf.summary.image('input_image', input_image)
    tf.summary.image('target_image', target_image)
    summary_op = tf.summary.merge_all()

    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)
        file_writer = tf.summary.FileWriter(logdir='./Tensorboard', graph=sess.graph)
        step = 1
        # Hoisted: the batch count does not change between epochs.
        n_batches = len(train_data) // mc.batch_size
        for e in range(mc.n_epochs):
            for b in range(n_batches):
                # Sample a random mini-batch without replacement.
                batch_indx = np.random.permutation(len(train_data))[:mc.batch_size]
                train_data_batch = [train_data[t] for t in batch_indx]
                train_target_batch = [train_target[t] for t in batch_indx]
                feed = {input_image: train_data_batch,
                        target_image: train_target_batch,
                        global_step: step}
                # One discriminator update, then one generator update per
                # batch (the original wrapped each in a pointless range(1)).
                sess.run(discriminator_optimizer, feed_dict=feed)
                sess.run(generator_optimizer, feed_dict=feed)
                s, l, dl, gl = sess.run(
                    [summary_op, l1_loss, discriminator_loss, generator_fake_loss],
                    feed_dict=feed)
                # BUG FIX: the "\r" carriage return was defeated by print's
                # trailing newline; suppress it so the line updates in place.
                print("\rEpoch: {}/{} \t Batch: {}/{} l1_loss: {} disc_loss: {} gen_loss: {}"
                      .format(e, mc.n_epochs, b, n_batches, l, dl, gl), end='')
                sys.stdout.flush()
                # BUG FIX: pass the step so TensorBoard plots summaries
                # against it instead of stacking them all at step 0.
                file_writer.add_summary(s, step)
                step += 1
            # BUG FIX: the Saver was created but never used, so trained
            # weights were lost; checkpoint once per epoch.
            saver.save(sess, './img2img.ckpt', global_step=e)
# Script entry point: build the GAN graph and start training.
if __name__ == '__main__':
    train()