Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion models/blizzard/vrnn_gauss.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,33 +334,45 @@ def main(args):
shared_updates = OrderedDict()
shared_updates[step_count] = step_count + 1

# Resets / Initializes the cell-state or the memory-state of each LSTM to
# zero.
s_0 = T.switch(T.eq(T.mod(step_count, reset_freq), 0),
rnn.get_init_state(batch_size), rnn_tm1)

# Forward Propagate the input to get more complex features for
# every time step.
x_1_temp = x_1.fprop([x], params)
x_2_temp = x_2.fprop([x_1_temp], params)
x_3_temp = x_3.fprop([x_2_temp], params)
x_4_temp = x_4.fprop([x_3_temp], params)


def inner_fn(x_t, s_tm1):

# Generate the mean and standard deviation of the
# latent variables Z_t | X_t for every time-step of the LSTM.
# This is a function of the input and the hidden state of the previous
# time step.
phi_1_t = phi_1.fprop([x_t, s_tm1], params)
phi_2_t = phi_2.fprop([phi_1_t], params)
phi_3_t = phi_3.fprop([phi_2_t], params)
phi_4_t = phi_4.fprop([phi_3_t], params)
phi_mu_t = phi_mu.fprop([phi_4_t], params)
phi_sig_t = phi_sig.fprop([phi_4_t], params)

# Prior on the latent variables at every time-step.
# Depends only on the hidden state of the previous time step.
prior_1_t = prior_1.fprop([s_tm1], params)
prior_2_t = prior_2.fprop([prior_1_t], params)
prior_3_t = prior_3.fprop([prior_2_t], params)
prior_4_t = prior_4.fprop([prior_3_t], params)
prior_mu_t = prior_mu.fprop([prior_4_t], params)
prior_sig_t = prior_sig.fprop([prior_4_t], params)

# Sample from the latent distribution with mean phi_mu_t
# and std phi_sig_t
z_t = Gaussian_sample(phi_mu_t, phi_sig_t)

# h_t = f(h_(t-1), z_t, x_t)
z_1_t = z_1.fprop([z_t], params)
z_2_t = z_2.fprop([z_1_t], params)
z_3_t = z_3.fprop([z_2_t], params)
Expand All @@ -370,6 +382,7 @@ def inner_fn(x_t, s_tm1):

return s_t, phi_mu_t, phi_sig_t, prior_mu_t, prior_sig_t, z_4_t

# Iterate over every time-step
((s_temp, phi_mu_temp, phi_sig_temp, prior_mu_temp, prior_sig_temp, z_4_temp), updates) =\
theano.scan(fn=inner_fn,
sequences=[x_4_temp],
Expand All @@ -380,6 +393,10 @@ def inner_fn(x_t, s_tm1):

shared_updates[rnn_tm1] = s_temp[-1]
s_temp = concatenate([s_0[None, :, :], s_temp[:-1]], axis=0)

# Generate the output distribution at every time-step.
# This is as a function of the latent variables and the hidden-state at
# every time-step.
theta_1_temp = theta_1.fprop([z_4_temp, s_temp], params)
theta_2_temp = theta_2.fprop([theta_1_temp], params)
theta_3_temp = theta_3.fprop([theta_2_temp], params)
Expand All @@ -395,13 +412,16 @@ def inner_fn(x_t, s_tm1):
nll_upper_bound = recon_term + kl_term
nll_upper_bound.name = 'nll_upper_bound'

# Forward-propagation of the validation data.
m_x_1_temp = x_1.fprop([m_x], params)
m_x_2_temp = x_2.fprop([m_x_1_temp], params)
m_x_3_temp = x_3.fprop([m_x_2_temp], params)
m_x_4_temp = x_4.fprop([m_x_3_temp], params)

m_s_0 = rnn.get_init_state(m_batch_size)

# Get the hidden-states, conditional mean, standard deviation, prior mean
# and prior standard deviation of the latent variables at every time-step.
((m_s_temp, m_phi_mu_temp, m_phi_sig_temp, m_prior_mu_temp, m_prior_sig_temp, m_z_4_temp), m_updates) =\
theano.scan(fn=inner_fn,
sequences=[m_x_4_temp],
Expand All @@ -410,6 +430,8 @@ def inner_fn(x_t, s_tm1):
for k, v in m_updates.iteritems():
k.default_update = v

# Get the inferred mean (X_t | Z_t) at every time-step of the validation
# data.
m_s_temp = concatenate([m_s_0[None, :, :], m_s_temp[:-1]], axis=0)
m_theta_1_temp = theta_1.fprop([m_z_4_temp, m_s_temp], params)
m_theta_2_temp = theta_2.fprop([m_theta_1_temp], params)
Expand All @@ -418,6 +440,7 @@ def inner_fn(x_t, s_tm1):
m_theta_mu_temp = theta_mu.fprop([m_theta_4_temp], params)
m_theta_sig_temp = theta_sig.fprop([m_theta_4_temp], params)

# Compute the data log-likelihood + KL-divergence on the validation data.
m_kl_temp = KLGaussianGaussian(m_phi_mu_temp, m_phi_sig_temp, m_prior_mu_temp, m_prior_sig_temp)

m_recon = Gaussian(m_x, m_theta_mu_temp, m_theta_sig_temp)
Expand Down