2 changes: 1 addition & 1 deletion code/kl_divergences.py
@@ -10,4 +10,4 @@ def diagonal_gaussian_kl(mean, std):
:return: The KL divergence.
"""
var = std ** 2
return 0.5 * (mx.sym.sum(1 + mx.sym.log(var) - mean ** 2 - var))
return -0.5 * (mx.sym.sum(1 + mx.sym.log(var) - mean ** 2 - var, axis=1))
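
Note on the sign fix above: for a diagonal Gaussian q = N(mean, diag(std^2)) against a standard-normal prior, the closed-form divergence is KL(q || p) = -0.5 * sum(1 + log(var) - mean^2 - var), which is non-negative, and summing over axis=1 yields one KL value per example rather than a single batch scalar. A minimal NumPy sanity check of that formula (standalone, not the MXNet symbol in the diff):

```python
import numpy as np

# Closed-form KL(N(mean, diag(std**2)) || N(0, I)), computed per row/example.
mean = np.array([[0.3, -1.2], [0.0, 0.5]])
std = np.array([[0.9, 1.1], [1.0, 0.4]])
var = std ** 2
kl = -0.5 * np.sum(1.0 + np.log(var) - mean ** 2 - var, axis=1)
print(kl)  # every entry is >= 0; the old 0.5 * (...) form returned the negation
```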
79 changes: 73 additions & 6 deletions code/run_vae.py
@@ -12,7 +12,7 @@
DEFAULT_LEARNING_RATE = 0.0003

data_names = ['train', 'valid', 'test']
train_set = ['train', 'valid']
train_set = ['train']
test_set = ['test']
data_dir = join(os.curdir, "binary_mnist")

@@ -86,25 +86,92 @@ def train_model(generator_layers: List[int],
mnist = load_data(train=True, logger=logger)
train_iter = mx.io.NDArrayIter(data=mnist['train'], data_name="data", label=mnist["train"], label_name="label",
batch_size=batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(data=mnist['valid'], data_name="data", label=mnist['valid'], label_name="label",
batch_size=batch_size)

vae = construct_vae(latent_type="gaussian", likelihood="bernoulliProd", generator_layer_sizes=generator_layers,
infer_layer_size=inference_layers, latent_variable_size=latent_size,
data_dims=mnist['train'].shape[1], generator_act_type='tanh', infer_act_type='tanh')

module = mx.module.Module(vae.train(mx.sym.Variable("data"), mx.sym.Variable('label')),
data_names=[train_iter.provide_data[0][0]],
label_names=["label"], context=ctx,
label_names=[train_iter.provide_label[0][0]], context=ctx,
logger=logger)

logger.info("Starting to train")

# module.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label,
# for_training=True, force_rebind=True)
# # if monitor is not None:
# # self.install_monitor(monitor)
# module.init_params()
# module.init_optimizer(optimizer=optimiser, optimizer_params={"learning_rate" : learning_rate})
#
# # if not isinstance(eval_metric, metric.EvalMetric):
# # eval_metric = metric.create(eval_metric)
#
# ################################################################################
# # training loop
# ################################################################################
# import time
# for epoch in range(0, epochs):
# tic = time.time()
# #eval_metric.reset()
# nbatch = 0
# data_iter = iter(train_iter)
# end_of_batch = False
# next_data_batch = next(data_iter)
# while not end_of_batch:
# data_batch = next_data_batch
# # if monitor is not None:
# # monitor.tic()
# module.forward_backward(data_batch)
# module.update()
# try:
# # pre fetch next batch
# next_data_batch = next(data_iter)
# module.prepare(next_data_batch)
# except StopIteration:
# end_of_batch = True
#
# # self.update_metric(eval_metric, data_batch.label)
#
# # if monitor is not None:
# # monitor.toc_print()
#
# # if batch_end_callback is not None:
# # batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
# # eval_metric=eval_metric,
# # locals=locals())
# # for callback in _as_list(batch_end_callback):
# # callback(batch_end_params)
# nbatch += 1
# print(nbatch)
#
# # one epoch of training is finished
# # for name, val in eval_metric.get_name_value():
# # module.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
# toc = time.time()
# module.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))
#
# # sync aux params across devices
# arg_params, aux_params = module.get_params()
# module.set_params(arg_params, aux_params)
#
# # if epoch_end_callback is not None:
# # for callback in _as_list(epoch_end_callback):
# # callback(epoch, self.symbol, arg_params, aux_params)
#
# # end of 1 epoch, reset the data-iter for another epoch
# train_iter.reset()

# print(module.get_outputs())
#
module.fit(train_data=train_iter, optimizer=optimiser, force_init=True, force_rebind=True, num_epoch=epochs,
optimizer_params={'learning_rate': learning_rate},
# validation_metric=mx.metric.Perplexity(None),
# eval_data=val_iter,
batch_end_callback=mx.callback.Speedometer(frequent=1, batch_size=batch_size),
epoch_end_callback=mx.callback.do_checkpoint('vae'))
batch_end_callback=mx.callback.Speedometer(frequent=20, batch_size=batch_size),
epoch_end_callback=mx.callback.do_checkpoint('vae'),
eval_metric="Loss")


def load_model(model_file: str):
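
Side note on the data/label wiring changed above: a VAE reconstructs its input, so the same array is fed as both data and label, and label_names is now taken from the iterator's provide_label instead of being hard-coded. A small sketch of that setup, assuming the MXNet 1.x NDArrayIter API and using a random stand-in for `mnist['train']`:

```python
import mxnet as mx
import numpy as np

# Random stand-in for mnist['train']: a small (N, 784) binarized array.
images = (np.random.rand(128, 784) > 0.5).astype('float32')

# The VAE reconstructs its input, so the images serve as both data and label.
train_iter = mx.io.NDArrayIter(data=images, data_name="data",
                               label=images, label_name="label",
                               batch_size=32, shuffle=True)

print(train_iter.provide_data[0][0])   # 'data'  -> used for data_names
print(train_iter.provide_label[0][0])  # 'label' -> used for label_names
```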
23 changes: 11 additions & 12 deletions code/vae.py
@@ -57,12 +57,12 @@ def _preactivation(self, latent_state: mx.sym.Symbol) -> mx.sym.Symbol:
:return: The pre-activation before output activation
"""
prev_out = None
for i, hidden in enumerate(self.layer_sizes):
fc_i = mx.sym.FullyConnected(data=latent_state, num_hidden=hidden, name="gen_fc_{}".format(i))
for i, size in enumerate(self.layer_sizes):
fc_i = mx.sym.FullyConnected(data=latent_state, num_hidden=size, name="gen_fc_{}".format(i))
prev_out = mx.sym.Activation(data=fc_i, act_type=self.act_type, name="gen_act_{}".format(i))

# The output layer that gives pre_activations for multiple Bernoulli softmax between 0 and 1
fc_out = mx.sym.FullyConnected(data=prev_out, num_hidden=2 * self.data_dims, name="gen_fc_out")
fc_out = mx.sym.FullyConnected(data=prev_out, num_hidden=self.data_dims, name="gen_fc_out")

return fc_out

@@ -73,11 +73,9 @@ def generate_sample(self, latent_state: mx.sym.Symbol) -> mx.sym.Symbol:
:param latent_state: The input latent state.
:return: A vector of Bernoulli draws.
"""
act = mx.sym.Activation(data=self._generate(latent_state=latent_state), act_type=self.output_act,
act = mx.sym.Activation(data=self._preactivation(latent_state=latent_state), act_type=self.output_act,
name="gen_act_out")
act = mx.ndarray(mx.sym.split(data=act, num_outputs=self.data_dims))
out = mx.sym.maximum(data=act, axis=0)

out = act > 0.5
return out

def train(self, latent_state=mx.sym.Symbol, label=mx.sym.Symbol) -> mx.sym.Symbol:
@@ -88,9 +86,9 @@ def train(self, latent_state=mx.sym.Symbol, label=mx.sym.Symbol) -> mx.sym.Symbol:
:param label: A binary vector (same as input for inference module)
:return: The loss symbol used for training
"""
output = self._preactivation(latent_state=latent_state)
output = mx.sym.reshape(data=output, shape=(-1, 2, self.data_dims))
return mx.sym.SoftmaxOutput(data=output, label=label, multi_output=True)
output = mx.sym.Activation(data=self._preactivation(latent_state=latent_state), act_type=self.output_act,
name="output_act")
return mx.sym.sum(label * mx.sym.log(output) + (1-label) * mx.sym.log(1-output), axis=[1])


class InferenceNetwork(ABC):
@@ -235,8 +233,9 @@ def train(self, data: mx.sym.Symbol, label: mx.sym.Symbol) -> mx.sym.Symbol:
"""
mean, std = self.inference_net.inference(data=data)
latent_state = self.inference_net.sample_latent_state(mean, std)
kl_loss = mx.sym.MakeLoss(self.kl_divergence(mean, std))
return self.generator.train(latent_state=latent_state, label=label)
kl_term = self.kl_divergence(mean, std)
log_likelihood = self.generator.train(latent_state=latent_state, label=label)
return mx.sym.MakeLoss(kl_term - log_likelihood, name="Neg_Elbo")

def generate_reconstructions(self, data: mx.sym.Symbol, n: int) -> mx.sym.Symbol:
"""
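
For reference, the loss assembled in VAE.train above is the per-example negative ELBO: the Gaussian KL term minus the Bernoulli log-likelihood of the reconstruction. A minimal NumPy sketch of that composition (not the MXNet graph itself; clipping the reconstruction probabilities is an added safeguard for the logarithm, not something the diff does):

```python
import numpy as np

def neg_elbo(x, recon_prob, mean, std):
    """Per-example negative ELBO: KL(q || N(0, I)) minus the Bernoulli log-likelihood."""
    var = std ** 2
    kl = -0.5 * np.sum(1.0 + np.log(var) - mean ** 2 - var, axis=1)
    log_lik = np.sum(x * np.log(recon_prob) + (1 - x) * np.log(1 - recon_prob), axis=1)
    return kl - log_lik  # minimised during training, matching MakeLoss(kl_term - log_likelihood)

rng = np.random.default_rng(0)
x = (rng.random((4, 8)) > 0.5).astype(float)          # binary targets
recon = np.clip(rng.random((4, 8)), 1e-6, 1 - 1e-6)   # reconstruction probabilities
mean = rng.normal(size=(4, 2))
std = np.abs(rng.normal(size=(4, 2))) + 0.1
print(neg_elbo(x, recon, mean, std))                  # shape (4,), one loss per example
```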