-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrain_beta_vae.py
More file actions
121 lines (97 loc) · 4.03 KB
/
train_beta_vae.py
File metadata and controls
121 lines (97 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import tensorflow as tf
import numpy as np
import os
import sys
from random import shuffle
import argparse
from utils.train_utils import *
from model import create_vae_with_elbo_loss
# ---------------------------------------------------------------------------
# Runtime configuration and command-line interface.
# ---------------------------------------------------------------------------
# TF1-style session config: expose a single GPU to the session.
config = tf.ConfigProto(device_count = {'GPU': 1})

parser = argparse.ArgumentParser()
parser.add_argument('logdir', type=str,
                    help='Directory to store checkpoints and summaries')
parser.add_argument('train_dir', type=str,
                    help='Directory with training data')
parser.add_argument('test_dir', type=str,
                    help='Directory with test data')
args = parser.parse_args()

# Hyperparameters.
batch_size = 128        # segments per mini-batch (placeholder has a fixed batch dim)
segment_length = 20     # time steps per segment
segment_channels = 80   # feature channels per time step -- TODO confirm (e.g. mel bins?)
encoder_units = 256     # encoder hidden size
decoder_units = 256     # decoder hidden size
num_epochs = 1200
num_latents = 32        # dimensionality of the latent code z
beta = 8.0              # KL weight passed to create_vae_with_elbo_loss (beta-VAE)
# ---------------------------------------------------------------------------
# Graph construction: model, loss, optimizer, and summaries.
# ---------------------------------------------------------------------------
# Input placeholder: a batch of fixed-size feature segments.
segment = tf.placeholder(
    tf.float32, shape=(batch_size, segment_length, segment_channels))

# Build the beta-VAE; total_loss is the (beta-weighted) ELBO objective.
total_loss, outputs = create_vae_with_elbo_loss(
    segment, segment_channels,
    encoder_units, decoder_units, num_latents,
    beta)
outs = outputs['encoder_outs']            # intermediate encoder activations
z_prior = outputs['z_prior']              # sample drawn from the prior p(z)
z = outputs['z_posterior_sample']         # sample drawn from q(z|x); fed with
                                          # z_prior at eval time to decode prior samples
segment_pred = outputs['x_reconst_mean']  # reconstruction mean

# Adam with library-default hyperparameters.
optimizer = tf.train.AdamOptimizer()
updates = optimizer.minimize(total_loss)

# Summaries: scalar loss, prior-sample decodes, input/reconstruction pairs
# (concatenated along the channel axis), and the first example's encoder
# activations rendered as an image.
loss_sum = tf.summary.scalar("total_loss", total_loss)
random_segment_sum = plot_segments(segment_pred, "random_segments")
reconst_segment_sum = plot_segments(
    tf.concat((segment, segment_pred), axis=2),
    "reconstructed_segments")
# Keep only the first example and append a channel axis so it is image-shaped.
outs = tf.expand_dims(outs[:1], -1)
outs_sum = tf.summary.image("encoder_outs", outs)
# ---------------------------------------------------------------------------
# Input pipelines and the persistent global step counter.
# ---------------------------------------------------------------------------
# NOTE(review): batches are drawn in file order (shuffle=False) for training
# too -- confirm that is intended rather than an oversight.
train_ds = create_batched_dataset(args.train_dir, batch_size, segment_channels,
                                  segment_length, shuffle=False)
# The test dataset repeats forever so an evaluation batch is always available.
test_ds = create_batched_dataset(args.test_dir, batch_size, segment_channels,
                                 segment_length, shuffle=False).repeat()

# Reinitializable iterator over the training data; re-initialized at the
# start of every epoch by running training_init_op.
train_iterator = tf.data.Iterator.from_structure(train_ds.output_types,
                                                 train_ds.output_shapes)
next_train_batch = train_iterator.get_next()
next_test_batch = test_ds.make_one_shot_iterator().get_next()
training_init_op = train_iterator.make_initializer(train_ds)

# Step counter kept as a (non-trainable) graph variable so that step
# numbering is saved/restored along with the model checkpoint.
global_i = tf.get_variable("global_i", dtype=tf.int32, trainable=False,
                           initializer=0)
inc_global_i = tf.assign_add(global_i, 1)
# ---------------------------------------------------------------------------
# Session setup: initialize variables, restore any checkpoint, open summaries.
# ---------------------------------------------------------------------------
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
# Presumably restores the latest checkpoint from logdir if one exists --
# see utils.train_utils.load for the exact behavior.
load(saver, sess, args.logdir)
summary_writer = tf.summary.FileWriter(args.logdir, sess.graph)

# Last observed losses; start at inf so the progress line is printable
# before the first train/eval step has produced a value.
test_loss_ = np.inf
train_loss_ = np.inf
# ---------------------------------------------------------------------------
# Training loop.
#
# Steps are counted globally across epochs via `global_i`.  Every 100th step
# is an evaluation step (loss + summaries on one test batch, no parameter
# update); every 10000th step additionally writes a checkpoint.  An epoch
# ends when the training iterator raises OutOfRangeError; Ctrl-C checkpoints
# and exits cleanly.
# ---------------------------------------------------------------------------
for epoch in range(num_epochs):
    # Initialize an iterator over the training dataset.
    sess.run(training_init_op)
    while True:
        sess.run(inc_global_i)
        i_ = sess.run(global_i)
        try:
            if i_ % 100 != 0:
                # Ordinary training step: fetch one batch, apply one update.
                train_seg_ = sess.run(next_train_batch)
                train_loss_, _ = sess.run(
                    [total_loss, updates], {segment: train_seg_})
            else:
                # Evaluation step: compute loss and summaries on a test batch
                # (the training batch for this step is skipped -- no update).
                test_seg_, z_prior_ = sess.run(
                    [next_test_batch, z_prior])
                test_loss_, loss_sum_, reconst_segment_sum_, outs_sum_ = (
                    sess.run([total_loss, loss_sum,
                              reconst_segment_sum, outs_sum],
                             {segment: test_seg_}))
                # Decode a prior sample by feeding it in place of the
                # posterior sample z.
                random_segment_sum_ = sess.run(
                    random_segment_sum, {z: z_prior_})
                summary_writer.add_summary(loss_sum_, i_)
                summary_writer.add_summary(reconst_segment_sum_, i_)
                summary_writer.add_summary(random_segment_sum_, i_)
                summary_writer.add_summary(outs_sum_, i_)
                # 10000 is a multiple of 100, so checkpoints always land on
                # evaluation steps.
                if i_ % 10000 == 0:
                    save(saver, sess, args.logdir, i_)
            print("Epoch: {}, step: {}, train_loss: {:.2f}, test_loss: {:.2f}"
                  .format(epoch, i_, train_loss_, test_loss_))
        except tf.errors.OutOfRangeError:
            # Training iterator exhausted: epoch finished.
            break
        except KeyboardInterrupt:
            # Checkpoint on Ctrl-C, then shut down cleanly.
            save(saver, sess, args.logdir, i_)
            sess.close()
            sys.exit()