
Conversation

@josephdviviano (Collaborator) commented Dec 16, 2025

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

  • Adds DiffusionMLE, a method for pre-training a diffusion-based pf on provided samples via maximum likelihood, using either a learned backward policy or a Brownian bridge (a sketch of the idea follows this list).
  • Adds RelativeTrajectoryBalanceGFlowNet, a method for fine-tuning a prior model toward a posterior distribution using the Relative Trajectory Balance (RTB) loss.
  • Adds a script showing how to use both, train_diffusion_rtb.py.
  • Adds some simple tests.
  • Minor changes to existing diffusion modules to allow for additional features such as learned variance.
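For intuition, here is a minimal pure-PyTorch sketch of the Brownian-bridge variant. This is not the DiffusionMLE API: brownian_bridge_backward, mle_loss, pf, and sigma below are illustrative assumptions. The idea is to noise a data sample x1 (at t = 1) backward to the origin with a Brownian bridge pinned at 0, then fit the forward policy by maximum likelihood on the reversed (denoising) transitions; the bridge mean s * (1 - dt / t) is the same quantity as the s_curr * (1.0 - self.dt / t_curr) expression discussed in the review below.

# Illustrative sketch only; not the DiffusionMLE API.
import math
import torch

def brownian_bridge_backward(x1, n_steps, sigma=1.0):
    """Noise x1 (at t=1) down to 0 (at t=0) with a Brownian bridge pinned at 0."""
    dt = 1.0 / n_steps
    traj, s, t = [x1], x1, 1.0
    for _ in range(n_steps - 1):
        mean = s * (1.0 - dt / t)                    # bridge mean for t -> t - dt
        std = sigma * math.sqrt(dt * (t - dt) / t)   # bridge std for t -> t - dt
        s = mean + std * torch.randn_like(s)
        traj.append(s)
        t -= dt
    traj.append(torch.zeros_like(x1))                # pinned at 0 at t = 0
    return traj[::-1]                                # forward order: s_0, ..., s_1 = x1

def mle_loss(pf, traj, sigma=1.0):
    """Negative log-likelihood of the forward (denoising) steps under pf.

    `pf(s, t)` is assumed to return the mean of the next-state Gaussian.
    """
    dt = 1.0 / (len(traj) - 1)
    logpf = 0.0
    for k in range(len(traj) - 1):
        mean = pf(traj[k], k * dt)
        dist = torch.distributions.Normal(mean, sigma * math.sqrt(dt))
        logpf = logpf + dist.log_prob(traj[k + 1]).sum()
    return -logpf  # minimizing the NLL maximizes the likelihood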

@josephdviviano self-assigned this Dec 16, 2025
codecov bot commented Dec 16, 2025

Codecov Report

❌ Patch coverage is 75.68389% with 80 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.33%. Comparing base (91aae26) to head (97a2453).
⚠️ Report is 25 commits behind head on master.

Files with missing lines                    Patch %    Lines
src/gfn/gym/diffusion_sampling.py           64.70%     27 Missing and 3 partials ⚠️
src/gfn/gflownet/mle.py                     81.90%     12 Missing and 7 partials ⚠️
src/gfn/utils/modules.py                    71.15%     10 Missing and 5 partials ⚠️
src/gfn/gflownet/trajectory_balance.py      70.27%     7 Missing and 4 partials ⚠️
src/gfn/estimators.py                       95.12%     1 Missing and 1 partial ⚠️
src/gfn/gym/helpers/diffusion_utils.py      0.00%      2 Missing ⚠️
src/gfn/gflownet/base.py                    80.00%     1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #457      +/-   ##
==========================================
+ Coverage   74.23%   74.33%   +0.09%     
==========================================
  Files          47       48       +1     
  Lines        6805     7112     +307     
  Branches      800      830      +30     
==========================================
+ Hits         5052     5287     +235     
- Misses       1449     1503      +54     
- Partials      304      322      +18     
Files with missing lines                    Coverage Δ
src/gfn/env.py                              81.29% <100.00%> (ø)
src/gfn/gflownet/base.py                    78.88% <80.00%> (-0.19%) ⬇️
src/gfn/estimators.py                       91.77% <95.12%> (-0.06%) ⬇️
src/gfn/gym/helpers/diffusion_utils.py      22.22% <0.00%> (-1.78%) ⬇️
src/gfn/gflownet/trajectory_balance.py      79.48% <70.27%> (-8.32%) ⬇️
src/gfn/utils/modules.py                    72.80% <71.15%> (+0.05%) ⬆️
src/gfn/gflownet/mle.py                     81.90% <81.90%> (ø)
src/gfn/gym/diffusion_sampling.py           48.75% <64.70%> (+4.89%) ⬆️

@josephdviviano marked this pull request as ready for review December 17, 2025 05:52
s_curr,
s_curr * self.dt / t_curr,
torch.zeros_like(s_curr),
s_curr * (1.0 - self.dt / t_curr),
Collaborator Author
@hyeok9855 this might have been the cause of the problem you had before in your code (slower learning); worth checking.

Collaborator Author
@hyeok9855 actually, this was my bug!!

recalculate_all_logprobs=recalculate_all_logprobs,
)

def trajectory_log_probs_backward(
Collaborator Author
leaving these here because they might come in handy, but I don't think they're actually needed right now in this implementation.

@josephdviviano (Collaborator Author)

@copilot how can I prevent codecov from failing CI? I want it to track coverage, but not block merge requests.

Copilot AI (Contributor) commented Dec 18, 2025

@josephdviviano I've opened a new pull request, #458, to work on those changes. Once the pull request is ready, I'll request review from you.
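For reference, the standard way to keep Codecov reporting coverage without blocking merges is to mark its status checks as informational in codecov.yml. A minimal sketch, assuming the default project and patch statuses (what #458 actually changes may differ):

# codecov.yml: report coverage, but never fail the CI status check.
coverage:
  status:
    project:
      default:
        informational: true  # the check is shown but always passes
    patch:
      default:
        informational: true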

Copilot AI (Contributor) left a comment
Pull request overview

This PR adds Relative Trajectory Balance (RTB) support for fine-tuning diffusion-based GFlowNets from a prior to a posterior distribution. It introduces a two-stage training pipeline: (1) maximum likelihood estimation (MLE) pre-training of a prior diffusion model using either a fixed Brownian bridge or learned backward policy, and (2) RTB-based fine-tuning that adapts the prior to match a posterior distribution weighted by a reward function.

Key changes include:

  • Implementation of RelativeTrajectoryBalanceGFlowNet for fine-tuning with the RTB loss (a minimal sketch of the objective appears after this list)
  • Implementation of MLEDiffusion trainer for prior pre-training via maximum likelihood
  • Support for learned variance in both forward and backward diffusion policies
  • New Gaussian mixture target distributions for RTB demonstrations (25-mode prior, 9-mode posterior)
  • Enhanced diffusion trajectory mask alignment with consistent terminal time detection
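For concreteness, here is a minimal sketch of the per-trajectory RTB objective as described above. The argument names are illustrative assumptions, not the class's actual attributes: the loss penalizes the squared log-mismatch between Z times the posterior trajectory likelihood and the beta-tempered, reward-weighted prior trajectory likelihood.

# Illustrative RTB loss sketch; names are assumptions, not the library API.
import torch

def rtb_loss(log_Z, logpf_posterior, logpf_prior, log_reward, beta=1.0):
    """Squared discrepancy between Z * pf_post(tau) and r(x)^beta * pf_prior(tau)."""
    delta = log_Z + logpf_posterior - logpf_prior - beta * log_reward
    return (delta**2).mean()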

Reviewed changes

Copilot reviewed 15 out of 16 changed files in this pull request and generated 19 comments.

Summary per file:

src/gfn/gflownet/trajectory_balance.py: Adds the RelativeTrajectoryBalanceGFlowNet class implementing the RTB objective
src/gfn/gflownet/mle.py: New MLE trainer for diffusion pre-training with backward sampling
src/gfn/estimators.py: Adds learned variance support and exploration noise handling to forward/backward estimators
src/gfn/utils/modules.py: Implements DiffusionPISGradNetBackward and adds learned variance to the forward module
src/gfn/gym/diffusion_sampling.py: Adds Grid25GaussianMixture and Posterior9of25GaussianMixture targets; improves terminal state detection
src/gfn/gym/helpers/diffusion_utils.py: Adds a max_n_samples parameter to the visualization utility
src/gfn/gflownet/base.py: Adds helper methods for computing forward/backward trajectory log-probs separately
src/gfn/gflownet/__init__.py: Exports RelativeTrajectoryBalanceGFlowNet
src/gfn/env.py: Minor comment clarification on boolean masking behavior
tutorials/examples/train_diffusion_rtb.py: Complete end-to-end tutorial demonstrating MLE pre-training and RTB fine-tuning
tutorials/examples/output/.gitignore: Ignores generated checkpoint and visualization files
testing/test_rtb.py: Unit tests for RTB loss computation and gradient flow
testing/test_environments.py: Test verifying diffusion trajectory mask alignment
testing/gym/test_diffusion_sampling_rtb.py: Tests for the new GMM target distributions
testing/gflownet/test_mle_diffusion.py: Comprehensive tests for the MLE diffusion trainer
.flake8: Minor formatting (trailing newline)
Comments suppressed due to low confidence (2)

src/gfn/gym/diffusion_sampling.py:69

            self.gt_xs = self.sample(n_gt_xs, seed)

src/gfn/gym/diffusion_sampling.py:70

            self.gt_xs_log_rewards = self.log_reward(self.gt_xs)


Comment on lines 410 to 411
os.makedirs("output", exist_ok=True)
plt.savefig(f"output/{prefix}simple_gmm.png")
Copilot AI commented Dec 18, 2025

The output directory is changed from "viz" to "output" across multiple target classes. While this change is consistent, it would be better to define the output directory as a constant or configuration parameter at the module level rather than hardcoding it in multiple places.

Collaborator Author
@copilot open a new pull request to apply changes based on this feedback

Comment on lines +1797 to +1874
class DiffusionPISGradNetBackward(nn.Module):
    """Learnable backward correction module (PIS-style) for diffusion.

    Produces mean and optional log-std corrections that are tanh-scaled by
    `pb_scale_range` to stay close to the analytic Brownian bridge.
    """

    def __init__(
        self,
        s_dim: int,
        harmonics_dim: int = 64,
        t_emb_dim: int = 64,
        s_emb_dim: int = 64,
        hidden_dim: int = 64,
        joint_layers: int = 2,
        zero_init: bool = False,
        clipping: bool = False,
        gfn_clip: float = 1e4,
        pb_scale_range: float = 0.1,
        log_var_range: float = 4.0,
        learn_variance: bool = True,
    ) -> None:
        super().__init__()
        self.s_dim = s_dim
        self.out_dim = s_dim + (1 if learn_variance else 0)
        self.harmonics_dim = harmonics_dim
        self.t_emb_dim = t_emb_dim
        self.s_emb_dim = s_emb_dim
        self.hidden_dim = hidden_dim
        self.joint_layers = joint_layers
        self.zero_init = zero_init
        self.clipping = clipping
        self.gfn_clip = gfn_clip
        self.pb_scale_range = pb_scale_range
        self.log_var_range = log_var_range
        self.learn_variance = learn_variance

        assert (
            self.s_emb_dim == self.t_emb_dim
        ), "Dimensionality of state embedding and time embedding should be the same!"

        self.t_model = DiffusionPISTimeEncoding(
            self.harmonics_dim, self.t_emb_dim, self.hidden_dim
        )
        self.s_model = DiffusionPISStateEncoding(self.s_dim, self.s_emb_dim)
        self.joint_model = DiffusionPISJointPolicy(
            self.s_emb_dim,
            self.hidden_dim,
            self.out_dim,
            self.joint_layers,
            self.zero_init,
        )

    def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor:
        # Split the preprocessed input into state coordinates and the scalar time.
        s = preprocessed_states[..., :-1]
        t = preprocessed_states[..., -1]
        s_emb = self.s_model(s)
        t_emb = self.t_model(t)
        out = self.joint_model(s_emb, t_emb)

        if self.clipping:
            out = torch.clamp(out, -self.gfn_clip, self.gfn_clip)

        # Tanh-scale to stay near Brownian bridge; last dim (if present) is log-std corr.
        drift_corr = torch.tanh(out[..., : self.s_dim]) * self.pb_scale_range
        if self.learn_variance and out.shape[-1] == self.s_dim + 1:
            log_std_corr = torch.tanh(out[..., [-1]]) * self.pb_scale_range
            log_std_corr = torch.clamp(
                log_std_corr, -self.log_var_range, self.log_var_range
            )
            out = torch.cat([drift_corr, log_std_corr], dim=-1)
        else:
            out = drift_corr

        if torch.isnan(out).any():
            raise ValueError("DiffusionPISGradNetBackward produced NaNs")

        return out
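For reference, a minimal shape check of the module above (a sketch, assuming the PIS encoder submodules it instantiates are importable alongside it; the values are illustrative):

# Illustrative shape check; assumes DiffusionPISGradNetBackward and its
# encoder submodules (DiffusionPISTimeEncoding, etc.) are importable.
import torch

net = DiffusionPISGradNetBackward(s_dim=2, learn_variance=True)
x = torch.cat([torch.randn(16, 2), torch.rand(16, 1)], dim=-1)  # states + time
out = net(x)
assert out.shape == (16, 3)  # 2 tanh-scaled drift corrections + 1 log-std corr.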
Copilot AI commented Dec 18, 2025

The new DiffusionPISGradNetBackward class lacks dedicated test coverage. While it's used indirectly in the tutorial script, there are no unit tests that directly verify its behavior (e.g., output shapes, learned corrections, variance handling). Consider adding tests similar to those for DiffusionPISGradNetForward to ensure the backward module works correctly.

Collaborator Author
I added test coverage.

Comment on lines +1375 to +1379
eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
is_final_step = (t_curr + self.dt) >= (1.0 - eps)
# TODO: The old code followed this convention (below). I believe the change
# is slightly more correct, but I'd like to check this during review.
# (1.0 - t_curr) < self.dt * 1e-2 # Triggers when t_curr ≈ 1.0
Copilot AI commented Dec 18, 2025

This TODO comment suggests uncertainty about the correctness of the exit condition change. TODOs requesting review in production code should be resolved before merging. Either verify the correctness and remove the TODO, or if there's genuine uncertainty, add a test to validate the behavior.

Suggested change:
- eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
- is_final_step = (t_curr + self.dt) >= (1.0 - eps)
- # TODO: The old code followed this convention (below). I believe the change
- # is slightly more correct, but I'd like to check this during review.
- # (1.0 - t_curr) < self.dt * 1e-2 # Triggers when t_curr ≈ 1.0
+ # Note: this replaces an older heuristic `(1.0 - t_curr) < self.dt * 1e-2`,
+ # using the shared `_DIFFUSION_TERMINAL_TIME_EPS` tolerance for consistency.
+ eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
+ is_final_step = (t_curr + self.dt) >= (1.0 - eps)

"""
assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1"
s_curr = states.tensor[:, :-1]
# s_curr = states.tensor[:, :-1]
Copilot AI commented Dec 18, 2025

The commented-out line appears to be dead code that should be removed. If it's intended for reference, consider moving it to a comment explaining why the change was made rather than leaving commented code.

Suggested change:
- # s_curr = states.tensor[:, :-1]

if self.debug and torch.isnan(logpf_sum).any():
raise ValueError("NaNs in logpf_sum during MLE loss.")

# TODO: Use included loss reduction helpers.
Copilot AI commented Dec 18, 2025

The TODO comment suggests using included loss reduction helpers, but the current implementation uses a simple conditional. Either implement the TODO or remove it if the current approach is acceptable.

Suggested change:
- # TODO: Use included loss reduction helpers.

self.prior_pf = prior_pf
self.beta = torch.tensor(beta)
self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ))
self.debug = debug # TODO: to be passed to base classes.
Copilot AI commented Dec 18, 2025

The TODO comment about passing debug flag to base classes should be resolved. Either implement the propagation of the debug flag to parent classes or remove the TODO if it's not necessary for this feature.

Suggested change:
- self.debug = debug  # TODO: to be passed to base classes.
+ self.debug = debug

josephdviviano and others added 10 commits December 17, 2025 22:13
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Configure codecov to not block CI while tracking coverage
Copilot AI (Contributor) commented Dec 18, 2025

@josephdviviano I've opened a new pull request, #459, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI and others added 2 commits December 18, 2025 04:05
Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
Refactor: Extract OUTPUT_DIR constant for visualization paths
# This matches the exit condition in DiffusionSampling.step() and ensures the
# sampled action is marked as an exit action (-inf) so trajectory masks align
# correctly in get_trajectory_pbs.
eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
Collaborator Author
Should revert this change; I am terminating one step too early.
