Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions energy_sampling/energies/many_well.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def doublewell_logprob(self, x):
return x1_term + x2_term

def manywell_logprob(self, x):
assert x.ndim == 2
logprob = torch.stack(
[self.doublewell_logprob(x[:, i*2:i*2+2]) for i in range(self.n_wells)],
dim=1).sum(dim=1)
assert x.ndim == 2 # [batch_size, ndim]
x_reshaped = x.view(-1, self.n_wells, 2).reshape(-1, 2) # [batch_size * n_wells, 2]
logprob = self.doublewell_logprob(x_reshaped) # [batch_size * n_wells]
logprob = logprob.reshape(-1, self.n_wells).sum(dim=1) # [batch_size]
return logprob

def sample_first_dimension(self, batch_size):
Expand Down
38 changes: 30 additions & 8 deletions energy_sampling/models/gfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,11 @@ def get_trajectory_fwd(self, s, discretizer, exploration_std, log_r, pis=False):
pf_mean, pflogvars = self.split_params(pfs)

logf[:, i] = flow
if self.partial_energy:
ref_log_var = (self.t_scale * ts[:, max(1, i)]).log()
log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s ** 2)).sum(1)
logf[:, i] += (1 - ts[:, i]) * log_p_ref + ts[:, i] * log_r(s)
# Note: We instead use the vectorized version outside of the loop
# if self.partial_energy:
# ref_log_var = (self.t_scale * ts[:, max(1, i)]).log()
# log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s ** 2)).sum(1)
# logf[:, i] += (1 - ts[:, i]) * log_p_ref + ts[:, i] * log_r(s)

if exploration_std is None:
if pis:
Expand Down Expand Up @@ -192,6 +193,16 @@ def get_trajectory_fwd(self, s, discretizer, exploration_std, log_r, pis=False):
s = s_
states[:, i + 1] = s

if self.partial_energy:
assert log_r is not None
ref_log_var = (self.t_scale * ts[:, 1:-1]).log().unsqueeze(2) # (bsz, T - 1, 1)
log_p_ref = -0.5 * (
logtwopi + ref_log_var + (-ref_log_var).exp() * (states[:, 1:-1] ** 2)
).sum(-1)
logf[:, 1:-1] += (1 - ts[:, 1:-1]) * log_p_ref + ts[:, 1:-1] * log_r(
states[:, 1:-1].reshape(-1, self.dim)
).view(bsz, trajectory_length - 1)

return states, logpf, logpb, logf

def get_trajectory_bwd(self, s, discretizer, exploration_std, log_r):
Expand Down Expand Up @@ -231,10 +242,11 @@ def get_trajectory_bwd(self, s, discretizer, exploration_std, log_r):
pf_mean, pflogvars = self.split_params(pfs)

logf[:, trajectory_length - i - 1] = flow
if self.partial_energy:
ref_log_var = (self.t_scale * ts[:, max(1, trajectory_length - i - 1)]).log()
log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s ** 2)).sum(1)
logf[:, trajectory_length - i - 1] += ts[:, trajectory_length - i - 1] * log_p_ref + ts[:, i + 1] * log_r(s)
# Note: We instead use the vectorized version outside of the loop
# if self.partial_energy:
# ref_log_var = (self.t_scale * ts[:, max(1, trajectory_length - i - 1)]).log()
# log_p_ref = -0.5 * (logtwopi + ref_log_var.unsqueeze(1) + (-ref_log_var).exp().unsqueeze(1) * (s_ ** 2)).sum(1)
# logf[:, trajectory_length - i - 1] += (1 - ts[:, trajectory_length - i - 1]) * log_p_ref + ts[:, trajectory_length - i - 1] * log_r(s_)

noise = ((s - s_) - dts.unsqueeze(1) * pf_mean) / (dts.sqrt().unsqueeze(1) * (pflogvars / 2).exp())
logpf[:, trajectory_length - i - 1] = -0.5 * (noise ** 2 + logtwopi + dts.log().unsqueeze(1) + pflogvars).sum(
Expand All @@ -243,6 +255,16 @@ def get_trajectory_bwd(self, s, discretizer, exploration_std, log_r):
s = s_
states[:, trajectory_length - i - 1] = s

if self.partial_energy:
assert log_r is not None
ref_log_var = (self.t_scale * ts[:, 1:-1]).log().unsqueeze(2) # (bsz, T - 1, 1)
log_p_ref = -0.5 * (
logtwopi + ref_log_var + (-ref_log_var).exp() * (states[:, 1:-1] ** 2)
).sum(-1)
logf[:, 1:-1] += (1 - ts[:, 1:-1]) * log_p_ref + ts[:, 1:-1] * log_r(
states[:, 1:-1].reshape(-1, self.dim)
).view(bsz, trajectory_length - 1)

return states, logpf, logpb, logf

def sample(self, batch_size, discretizer, log_r):
Expand Down