From 2a751a70be7e0b6eb6cb433936029335d47f179d Mon Sep 17 00:00:00 2001
From: hyeok9855
Date: Thu, 10 Apr 2025 05:07:41 +0900
Subject: [PATCH] make subtb efficient

---
 energy_sampling/gflownet_losses.py | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/energy_sampling/gflownet_losses.py b/energy_sampling/gflownet_losses.py
index 907bac84..83762e8f 100644
--- a/energy_sampling/gflownet_losses.py
+++ b/energy_sampling/gflownet_losses.py
@@ -68,18 +68,16 @@ def subtb(initial_state, gfn, log_reward_fn, coef_matrix, exploration_std=None,
 
     diff_logp = log_pfs - log_pbs
     diff_logp_padded = torch.cat(
-        (torch.zeros((diff_logp.shape[0], 1)).to(diff_logp),
-         diff_logp.cumsum(dim=-1)),
-        dim=1)
+        (torch.zeros((diff_logp.shape[0], 1)).to(diff_logp), diff_logp.cumsum(dim=-1)),
+        dim=1,
+    )
     A1 = diff_logp_padded.unsqueeze(1) - diff_logp_padded.unsqueeze(2)
     A2 = log_fs[:, :, None] - log_fs[:, None, :] + A1
-    A2 = A2 ** 2
+    loss = torch.triu((A2 ** 2) * coef_matrix.unsqueeze(0), diagonal=1).sum((1, 2))
     if return_exp:
-        return torch.stack([torch.triu(A2[i] * coef_matrix, diagonal=1).sum() for i in range(A2.shape[0])]).sum(), states, log_pfs, log_pbs, log_fs[:, -1]
+        return 0.5 * loss.mean(), states, log_pfs, log_pbs, log_fs[:, -1]
     else:
-
-        return torch.stack([torch.triu(A2[i] * coef_matrix, diagonal=1).sum() for i in range(A2.shape[0])]).sum()
-
+        return 0.5 * loss.mean()
 
 
 def bwd_mle(samples, gfn, log_reward_fn, exploration_std=None):
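
A minimal self-contained sketch (not part of the patch) illustrating the change: it builds random tensors with the shapes the diff implies (log_pfs/log_pbs of shape [batch, T], log_fs and coef_matrix over T + 1 states) and checks that the batched torch.triu reduction introduced here matches the original per-sample loop, up to the 0.5 * mean() vs. sum() rescaling the patch also makes. The shapes and dummy values below are illustrative assumptions, not taken from the repository.

import torch

batch_size, T = 4, 8                       # assumed: T steps -> T + 1 states per trajectory
log_pfs = torch.randn(batch_size, T)       # forward log-probabilities per step
log_pbs = torch.randn(batch_size, T)       # backward log-probabilities per step
log_fs = torch.randn(batch_size, T + 1)    # log flows at every state
coef_matrix = torch.rand(T + 1, T + 1)     # sub-trajectory weight coefficients

# Shared setup, identical before and after the patch.
diff_logp = log_pfs - log_pbs
diff_logp_padded = torch.cat(
    (torch.zeros((diff_logp.shape[0], 1)).to(diff_logp), diff_logp.cumsum(dim=-1)),
    dim=1,
)
A1 = diff_logp_padded.unsqueeze(1) - diff_logp_padded.unsqueeze(2)
A2 = log_fs[:, :, None] - log_fs[:, None, :] + A1

# Old path: Python loop over the batch, result summed over samples.
old = torch.stack(
    [torch.triu((A2[i] ** 2) * coef_matrix, diagonal=1).sum() for i in range(A2.shape[0])]
).sum()

# New path: one batched triu over the whole [batch, T+1, T+1] tensor,
# then 0.5 * mean over the batch.
loss = torch.triu((A2 ** 2) * coef_matrix.unsqueeze(0), diagonal=1).sum((1, 2))
new = 0.5 * loss.mean()

# Same quantity up to the rescaling the patch introduces: old == 2 * batch_size * new.
assert torch.allclose(old, 2 * batch_size * new)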