-
Notifications
You must be signed in to change notification settings - Fork 2
Open
Description
Hi, thanks for sharing the code! I am wondering why you do "
else:
ldj = self.log_likelihood(z)
return z, ldj
" in the following code.
if not reverse:
ldj = z.new_zeros(batch_size, )
z = z.reshape((batch_size, self.d_in))
device = z.device
print("checking z ssssshape", z.shape) # [1, 192])
print("checking self.training ", self.training) # False
if self.training:
t = th.randint(0, self.T, (batch_size, ), device=device).long()
print("checking t", t)
# loss mark
ldj = -self.training_losses(z, t, x_cat, **kwargs)['loss']
z = z.reshape((batch_size, set_size, hidden_dim))
return z, ldj
else:
ldj = self.log_likelihood(z)
return z, ldj
else:
ldj = self.nll(z)
Metadata
Assignees
Labels
No labels