Relative trajectory balance #457
Codecov Report ❌ Patch coverage is …

Coverage diff, master vs. #457:

| Metric | master | #457 | +/- |
|---|---|---|---|
| Coverage | 74.23% | 74.33% | +0.09% |
| Files | 47 | 48 | +1 |
| Lines | 6805 | 7112 | +307 |
| Branches | 800 | 830 | +30 |
| Hits | 5052 | 5287 | +235 |
| Misses | 1449 | 1503 | +54 |
| Partials | 304 | 322 | +18 |
src/gfn/estimators.py
Outdated
    s_curr,
    s_curr * self.dt / t_curr,
    torch.zeros_like(s_curr),
    s_curr * (1.0 - self.dt / t_curr),
@hyeok9855 this might have been the cause of the problem you had before in your code (learning slower), worth checking.
@hyeok9855 actually, this was my bug!!
            recalculate_all_logprobs=recalculate_all_logprobs,
        )

    def trajectory_log_probs_backward(
leaving these here because they might come in handy, but I don't think they're actually needed right now in this implementation.
…gfn into relative_trajectory_balance
@copilot how can I prevent codecov from failing CI? I want it to track coverage, but not block merge requests.

@josephdviviano I've opened a new pull request, #458, to work on those changes. Once the pull request is ready, I'll request review from you.
Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
Pull request overview
This PR adds Relative Trajectory Balance (RTB) support for fine-tuning diffusion-based GFlowNets from a prior to a posterior distribution. It introduces a two-stage training pipeline: (1) maximum likelihood estimation (MLE) pre-training of a prior diffusion model using either a fixed Brownian bridge or learned backward policy, and (2) RTB-based fine-tuning that adapts the prior to match a posterior distribution weighted by a reward function.
Key changes include:
- Implementation of `RelativeTrajectoryBalanceGFlowNet` for fine-tuning with the RTB loss
- Implementation of `MLEDiffusion` trainer for prior pre-training via maximum likelihood
- Support for learned variance in both forward and backward diffusion policies
- New Gaussian mixture target distributions for RTB demonstrations (25-mode prior, 9-mode posterior)
- Enhanced diffusion trajectory mask alignment with consistent terminal time detection
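
For readers unfamiliar with the objective, the RTB loss summarized in the overview can be sketched for a single trajectory as follows. This is a minimal plain-Python illustration, not the code in `RelativeTrajectoryBalanceGFlowNet`: the function name `rtb_loss`, its arguments, and the placement of `beta` as a multiplier on the log-reward are assumptions made for the example.

```python
def rtb_loss(log_z, logp_posterior_steps, logp_prior_steps, log_reward, beta=1.0):
    """Squared RTB residual for one trajectory (illustrative sketch).

    residual = log Z + sum_t log p_post(step t) - sum_t log p_prior(step t)
               - beta * log r(x)

    `beta` is assumed here to scale the log-reward; the PR's exact
    convention is not shown in this conversation.
    """
    log_p_post = sum(logp_posterior_steps)    # trajectory log-prob under the fine-tuned model
    log_p_prior = sum(logp_prior_steps)       # trajectory log-prob under the frozen prior
    residual = log_z + log_p_post - log_p_prior - beta * log_reward
    return residual ** 2


# The loss vanishes when posterior and prior agree and the reward is flat.
balanced = rtb_loss(0.0, [-1.0, -2.0], [-1.5, -1.5], log_reward=0.0)

# A posterior that is more likely than the prior, with no reward to justify it,
# is penalized.
unbalanced = rtb_loss(0.0, [-1.0], [-2.0], log_reward=0.0)
```

In a real training loop the per-step log-probabilities would be tensors and only the posterior policy and `log_z` would receive gradients; the prior stays fixed.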
Reviewed changes
Copilot reviewed 15 out of 16 changed files in this pull request and generated 19 comments.
| File | Description |
|---|---|
| `src/gfn/gflownet/trajectory_balance.py` | Adds RelativeTrajectoryBalanceGFlowNet class implementing the RTB objective |
| `src/gfn/gflownet/mle.py` | New MLE trainer for diffusion pre-training with backward sampling |
| `src/gfn/estimators.py` | Adds learned variance support and exploration noise handling to forward/backward estimators |
| `src/gfn/utils/modules.py` | Implements DiffusionPISGradNetBackward and adds learned variance to forward module |
| `src/gfn/gym/diffusion_sampling.py` | Adds Grid25GaussianMixture and Posterior9of25GaussianMixture targets; improves terminal state detection |
| `src/gfn/gym/helpers/diffusion_utils.py` | Adds max_n_samples parameter to visualization utility |
| `src/gfn/gflownet/base.py` | Adds helper methods for computing forward/backward trajectory log-probs separately |
| `src/gfn/gflownet/__init__.py` | Exports RelativeTrajectoryBalanceGFlowNet |
| `src/gfn/env.py` | Minor comment clarification on boolean masking behavior |
| `tutorials/examples/train_diffusion_rtb.py` | Complete end-to-end tutorial demonstrating MLE pre-training and RTB fine-tuning |
| `tutorials/examples/output/.gitignore` | Ignores generated checkpoint and visualization files |
| `testing/test_rtb.py` | Unit tests for RTB loss computation and gradient flow |
| `testing/test_environments.py` | Test verifying diffusion trajectory mask alignment |
| `testing/gym/test_diffusion_sampling_rtb.py` | Tests for new GMM target distributions |
| `testing/gflownet/test_mle_diffusion.py` | Comprehensive tests for MLE diffusion trainer |
| `.flake8` | Minor formatting (trailing newline) |
Comments suppressed due to low confidence (2)
src/gfn/gym/diffusion_sampling.py:69
- This call to BaseTarget.sample in an initialization method is overridden by SimpleGaussianMixture.sample, Grid25GaussianMixture.sample, Posterior9of25GaussianMixture.sample, Funnel.sample, and ManyWell.sample.

    self.gt_xs = self.sample(n_gt_xs, seed)

src/gfn/gym/diffusion_sampling.py:70
- This call to BaseTarget.log_reward in an initialization method is overridden by SimpleGaussianMixture.log_reward, Grid25GaussianMixture.log_reward, Posterior9of25GaussianMixture.log_reward, Funnel.log_reward, and ManyWell.log_reward.

    self.gt_xs_log_rewards = self.log_reward(self.gt_xs)
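
The pattern these suppressed comments flag is a general Python pitfall: a base-class `__init__` that calls an overridable method runs the subclass override before the subclass has finished initializing. The sketch below uses toy classes (`ToyBaseTarget`, `FragileTarget`, `SafeTarget` are hypothetical names, not the library's classes) to show when this breaks and one way to avoid it.

```python
class ToyBaseTarget:
    """Toy stand-in for a base class whose __init__ calls an overridable method."""

    def __init__(self, n_gt_xs):
        # Dispatches to the subclass override, possibly before the subclass
        # has set the attributes that override depends on.
        self.gt_xs = self.sample(n_gt_xs)

    def sample(self, n):
        raise NotImplementedError


class FragileTarget(ToyBaseTarget):
    def __init__(self, n_gt_xs, scale):
        super().__init__(n_gt_xs)  # sample() runs here, before self.scale exists
        self.scale = scale

    def sample(self, n):
        return [self.scale * i for i in range(n)]


class SafeTarget(ToyBaseTarget):
    def __init__(self, n_gt_xs, scale):
        self.scale = scale         # set everything sample() needs first
        super().__init__(n_gt_xs)

    def sample(self, n):
        return [self.scale * i for i in range(n)]


def fragile_fails():
    """Return True if constructing FragileTarget raises AttributeError."""
    try:
        FragileTarget(3, 2.0)
    except AttributeError:
        return True
    return False
```

Whether the real targets are affected depends on the order in which each subclass sets its attributes relative to `super().__init__()`, which this conversation does not show.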
src/gfn/gym/diffusion_sampling.py
Outdated
    os.makedirs("output", exist_ok=True)
    plt.savefig(f"output/{prefix}simple_gmm.png")
Copilot AI, Dec 18, 2025
The output directory is changed from "viz" to "output" across multiple target classes. While this change is consistent, it would be better to define the output directory as a constant or configuration parameter at the module level rather than hardcoding it in multiple places.
@copilot open a new pull request to apply changes based on this feedback
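
A minimal sketch of the refactor requested in this thread, assuming a module-level constant named `OUTPUT_DIR` (the name that appears in the follow-up commit title later in this conversation). The helper functions `figure_path` and `ensure_output_dir` are hypothetical and only illustrate how call sites could stop hardcoding the directory.

```python
import os

# Single module-level definition; every target class would reuse it
# instead of repeating the "output" literal.
OUTPUT_DIR = "output"


def figure_path(filename, prefix="", output_dir=OUTPUT_DIR):
    """Build the path for a saved figure without touching the filesystem."""
    return os.path.join(output_dir, f"{prefix}{filename}")


def ensure_output_dir(output_dir=OUTPUT_DIR):
    """Create the output directory if it does not exist yet and return it."""
    os.makedirs(output_dir, exist_ok=True)
    return output_dir
```

A plotting method would then call `ensure_output_dir()` once and pass `figure_path("simple_gmm.png", prefix=prefix)` to `plt.savefig`, so changing the directory is a one-line edit.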
    class DiffusionPISGradNetBackward(nn.Module):
        """Learnable backward correction module (PIS-style) for diffusion.
        Produces mean and optional log-std corrections that are tanh-scaled by
        `pb_scale_range` to stay close to the analytic Brownian bridge.
        """

        def __init__(
            self,
            s_dim: int,
            harmonics_dim: int = 64,
            t_emb_dim: int = 64,
            s_emb_dim: int = 64,
            hidden_dim: int = 64,
            joint_layers: int = 2,
            zero_init: bool = False,
            clipping: bool = False,
            gfn_clip: float = 1e4,
            pb_scale_range: float = 0.1,
            log_var_range: float = 4.0,
            learn_variance: bool = True,
        ) -> None:
            super().__init__()
            self.s_dim = s_dim
            self.out_dim = s_dim + (1 if learn_variance else 0)
            self.harmonics_dim = harmonics_dim
            self.t_emb_dim = t_emb_dim
            self.s_emb_dim = s_emb_dim
            self.hidden_dim = hidden_dim
            self.joint_layers = joint_layers
            self.zero_init = zero_init
            self.clipping = clipping
            self.gfn_clip = gfn_clip
            self.pb_scale_range = pb_scale_range
            self.log_var_range = log_var_range
            self.learn_variance = learn_variance

            assert (
                self.s_emb_dim == self.t_emb_dim
            ), "Dimensionality of state embedding and time embedding should be the same!"

            self.t_model = DiffusionPISTimeEncoding(
                self.harmonics_dim, self.t_emb_dim, self.hidden_dim
            )
            self.s_model = DiffusionPISStateEncoding(self.s_dim, self.s_emb_dim)
            self.joint_model = DiffusionPISJointPolicy(
                self.s_emb_dim,
                self.hidden_dim,
                self.out_dim,
                self.joint_layers,
                self.zero_init,
            )

        def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor:
            s = preprocessed_states[..., :-1]
            t = preprocessed_states[..., -1]
            s_emb = self.s_model(s)
            t_emb = self.t_model(t)
            out = self.joint_model(s_emb, t_emb)

            if self.clipping:
                out = torch.clamp(out, -self.gfn_clip, self.gfn_clip)

            # Tanh-scale to stay near Brownian bridge; last dim (if present) is log-std corr.
            drift_corr = torch.tanh(out[..., : self.s_dim]) * self.pb_scale_range
            if self.learn_variance and out.shape[-1] == self.s_dim + 1:
                log_std_corr = torch.tanh(out[..., [-1]]) * self.pb_scale_range
                log_std_corr = torch.clamp(
                    log_std_corr, -self.log_var_range, self.log_var_range
                )
                out = torch.cat([drift_corr, log_std_corr], dim=-1)
            else:
                out = drift_corr

            if torch.isnan(out).any():
                raise ValueError("DiffusionPISGradNetBackward produced NaNs")

            return out
Copilot AI, Dec 18, 2025
The new DiffusionPISGradNetBackward class lacks dedicated test coverage. While it's used indirectly in the tutorial script, there are no unit tests that directly verify its behavior (e.g., output shapes, learned corrections, variance handling). Consider adding tests similar to those for DiffusionPISGradNetForward to ensure the backward module works correctly.
I added test coverage.
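
As a side note on what such tests could assert, the tanh scaling in the quoted `forward` bounds every correction by `pb_scale_range`. The plain-Python sketch below mirrors that arithmetic on scalars, using the default values from the quoted constructor; the function names are hypothetical. It also suggests that, with those defaults, the subsequent clamp to `log_var_range` never changes the value, since `0.1 < 4.0`.

```python
import math

PB_SCALE_RANGE = 0.1  # default `pb_scale_range` in the quoted constructor
LOG_VAR_RANGE = 4.0   # default `log_var_range` in the quoted constructor


def tanh_scaled(raw, scale_range=PB_SCALE_RANGE):
    """Squash an unbounded network output into [-scale_range, scale_range]."""
    return math.tanh(raw) * scale_range


def clamp(value, low, high):
    """Scalar equivalent of clamping a value to [low, high]."""
    return max(low, min(high, value))


def log_std_correction(raw):
    """Scalar version of the log-std branch: tanh-scale, then clamp."""
    return clamp(tanh_scaled(raw), -LOG_VAR_RANGE, LOG_VAR_RANGE)


samples = [-1e6, -1.0, 0.0, 1.0, 1e6]
bounded = all(abs(tanh_scaled(x)) <= PB_SCALE_RANGE for x in samples)
clamp_is_noop = all(log_std_correction(x) == tanh_scaled(x) for x in samples)
```

A tensor-level test would check the same bound on the module's output, along with the output shape for `learn_variance=True` and `learn_variance=False`.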
    eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
    is_final_step = (t_curr + self.dt) >= (1.0 - eps)
    # TODO: The old code followed this convention (below). I believe the change
    # is slightly more correct, but I'd like to check this during review.
    # (1.0 - t_curr) < self.dt * 1e-2 # Triggers when t_curr ≈ 1.0
Copilot AI, Dec 18, 2025
This TODO comment suggests uncertainty about the correctness of the exit condition change. TODOs requesting review in production code should be resolved before merging. Either verify the correctness and remove the TODO, or if there's genuine uncertainty, add a test to validate the behavior.
Suggested change:

    - eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
    - is_final_step = (t_curr + self.dt) >= (1.0 - eps)
    - # TODO: The old code followed this convention (below). I believe the change
    - # is slightly more correct, but I'd like to check this during review.
    - # (1.0 - t_curr) < self.dt * 1e-2 # Triggers when t_curr ≈ 1.0
    + # Note: this replaces an older heuristic `(1.0 - t_curr) < self.dt * 1e-2`,
    + # using the shared `_DIFFUSION_TERMINAL_TIME_EPS` tolerance for consistency.
    + eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
    + is_final_step = (t_curr + self.dt) >= (1.0 - eps)
    """
    assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1"
    s_curr = states.tensor[:, :-1]
    # s_curr = states.tensor[:, :-1]
Copilot AI, Dec 18, 2025
The commented-out line appears to be dead code that should be removed. If it's intended for reference, consider moving it to a comment explaining why the change was made rather than leaving commented code.
Suggested change (remove this line):

    - # s_curr = states.tensor[:, :-1]
    if self.debug and torch.isnan(logpf_sum).any():
        raise ValueError("NaNs in logpf_sum during MLE loss.")

    # TODO: Use included loss reduction helpers.
Copilot AI, Dec 18, 2025
The TODO comment suggests using included loss reduction helpers, but the current implementation uses a simple conditional. Either implement the TODO or remove it if the current approach is acceptable.
Suggested change (remove this line):

    - # TODO: Use included loss reduction helpers.
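
For context on what the TODO is pointing at, a loss reduction helper typically centralizes the choice between mean, sum, and no reduction so that each loss does not repeat the conditional. The sketch below is a generic plain-Python illustration; `reduce_loss` is a hypothetical name and is not the helper shipped with the library.

```python
def reduce_loss(per_trajectory_losses, reduction="mean"):
    """Reduce a batch of per-trajectory losses (illustrative helper).

    reduction: "mean", "sum", or "none" (return the values unreduced).
    """
    losses = list(per_trajectory_losses)
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"Unknown reduction: {reduction!r}")
```

Routing the MLE loss through one such function would keep its reduction semantics aligned with the other losses in the codebase.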
    self.prior_pf = prior_pf
    self.beta = torch.tensor(beta)
    self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ))
    self.debug = debug # TODO: to be passed to base classes.
Copilot AI, Dec 18, 2025
The TODO comment about passing debug flag to base classes should be resolved. Either implement the propagation of the debug flag to parent classes or remove the TODO if it's not necessary for this feature.
Suggested change:

    - self.debug = debug # TODO: to be passed to base classes.
    + self.debug = debug
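
One detail in the quoted snippet that may be worth a second look is `logZ or nn.Parameter(...)`. In Python, `x or default` falls back to `default` whenever `x` is falsy, not only when it is `None`, and a single-element tensor is falsy when its value is zero. The scalar sketch below illustrates the difference; `pick_with_or` and `pick_with_is_none` are hypothetical names, and whether this matters in the PR depends on whether callers ever pass a zero-valued `logZ`.

```python
def pick_with_or(log_z, init_log_z=1.0):
    """Fallback via `or`: also replaces a legitimate value of 0.0."""
    return log_z or init_log_z


def pick_with_is_none(log_z, init_log_z=1.0):
    """Fallback via an explicit None check: only replaces a missing value."""
    return init_log_z if log_z is None else log_z


# A caller-supplied logZ of exactly 0.0 is silently discarded by `or`...
or_result = pick_with_or(0.0)
# ...but preserved by the explicit check.
none_check_result = pick_with_is_none(0.0)
```

With tensors there is a second difference: applying `or` to a tensor with more than one element raises an error, while the `is None` check works for any shape.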
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Configure codecov to not block CI while tracking coverage
@josephdviviano I've opened a new pull request, #459, to work on those changes. Once the pull request is ready, I'll request review from you.
Co-authored-by: josephdviviano <4142570+josephdviviano@users.noreply.github.com>
Refactor: Extract OUTPUT_DIR constant for visualization paths
    # This matches the exit condition in DiffusionSampling.step() and ensures the
    # sampled action is marked as an exit action (-inf) so trajectory masks align
    # correctly in get_trajectory_pbs.
    eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
Should revert change - am terminating one step too early.
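
The off-by-one described here can be reproduced with a small plain-Python check that compares the two exit conditions quoted earlier in this conversation. The value of `_DIFFUSION_TERMINAL_TIME_EPS` is not shown in the thread, so `1e-2` below is an assumption, as are the helper names.

```python
TERMINAL_TIME_EPS = 1e-2  # assumed value; the real constant is not shown in the thread


def is_final_step_new(t_curr, dt, eps_scale=TERMINAL_TIME_EPS):
    """Condition introduced in this PR: fires when the next step would reach t = 1."""
    eps = dt * eps_scale
    return (t_curr + dt) >= (1.0 - eps)


def is_final_step_old(t_curr, dt):
    """Older heuristic quoted in the TODO: fires when t_curr is already about 1."""
    return (1.0 - t_curr) < dt * 1e-2


def first_trigger(predicate, num_steps):
    """Index k of the first grid time t = k * dt at which the predicate fires."""
    dt = 1.0 / num_steps
    for k in range(num_steps + 1):
        if predicate(k * dt, dt):
            return k
    return None


# With 4 steps (dt = 0.25, exactly representable in floating point):
new_index = first_trigger(is_final_step_new, 4)  # fires at t = 0.75
old_index = first_trigger(is_final_step_old, 4)  # fires at t = 1.00
```

The new condition fires one grid step before the old one. Which of the two is correct depends on whether the caller treats the flag as "this step lands on the terminal time" or "the state is already terminal".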
Description
- Pre-trains `pf` using either a learned backward policy or a Brownian bridge from provided samples using maximum likelihood, `DiffusionMLE`.
- `RelativeTrajectoryBalanceGFlowNet`.
- `train_diffusion_rtb.py`.