-
Notifications
You must be signed in to change notification settings - Fork 55
Relative trajectory balance #457
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
ee17515
85190b0
6ecd0ca
08258ad
6ea54cd
a0666da
f771c2c
51ae63d
4aebf20
f685b86
e647877
06ccc10
bfbbc22
8f8cadb
bb2cb45
69efb35
dd2c4e4
c8eb351
63a4bd1
46807ed
b40655b
74b8a60
7deba0e
a2f5a6c
ee22f3c
f8d5c5b
ad457f6
f439a78
b141f3f
97a2453
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| coverage: | ||
| status: | ||
| project: | ||
| default: | ||
| # Set to informational only - will not block PRs | ||
| informational: true | ||
| patch: | ||
| default: | ||
| # Set to informational only - will not block PRs | ||
| informational: true | ||
|
|
||
| comment: | ||
| # Still show coverage comments on PRs | ||
| layout: "diff, flags, files" | ||
| behavior: default |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,3 +1,4 @@ | ||||||||||||||||||||||||||||||
| import math | ||||||||||||||||||||||||||||||
| from abc import ABC, abstractmethod | ||||||||||||||||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||||||||||||||
| from typing import Any, Callable, Dict, List, Optional, Protocol, cast, runtime_checkable | ||||||||||||||||||||||||||||||
|
|
@@ -26,6 +27,12 @@ | |||||||||||||||||||||||||||||
| "prod": torch.prod, | ||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| # Relative tolerance for detecting terminal time in diffusion estimators. | ||||||||||||||||||||||||||||||
| # Must match TERMINAL_TIME_EPS in gfn.gym.diffusion_sampling to ensure consistent | ||||||||||||||||||||||||||||||
| # exit action detection between the estimator and environment. TODO: we should handle this | ||||||||||||||||||||||||||||||
| # centrally somewhere. | ||||||||||||||||||||||||||||||
| _DIFFUSION_TERMINAL_TIME_EPS = 1e-2 | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| class RolloutContext: | ||||||||||||||||||||||||||||||
| """Structured per‑rollout state owned by estimators. | ||||||||||||||||||||||||||||||
|
|
@@ -1290,6 +1297,7 @@ def __init__( | |||||||||||||||||||||||||||||
| pf_module: nn.Module, | ||||||||||||||||||||||||||||||
| sigma: float, | ||||||||||||||||||||||||||||||
| num_discretization_steps: int, | ||||||||||||||||||||||||||||||
| n_variance_outputs: int = 0, | ||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||
| """Initialize the PinnedBrownianMotionForward. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
|
|
@@ -1305,6 +1313,12 @@ def __init__( | |||||||||||||||||||||||||||||
| self.sigma = sigma | ||||||||||||||||||||||||||||||
| self.num_discretization_steps = num_discretization_steps | ||||||||||||||||||||||||||||||
| self.dt = 1.0 / self.num_discretization_steps | ||||||||||||||||||||||||||||||
| self.n_variance_outputs = n_variance_outputs | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| @property | ||||||||||||||||||||||||||||||
| def expected_output_dim(self) -> int: | ||||||||||||||||||||||||||||||
| # Drift (s_dim) plus optional variance outputs. | ||||||||||||||||||||||||||||||
| return self.s_dim + self.n_variance_outputs | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| def forward(self, input: States) -> torch.Tensor: | ||||||||||||||||||||||||||||||
| """Forward pass of the module. | ||||||||||||||||||||||||||||||
|
|
@@ -1329,7 +1343,6 @@ def to_probability_distribution( | |||||||||||||||||||||||||||||
| states: States, | ||||||||||||||||||||||||||||||
| module_output: torch.Tensor, | ||||||||||||||||||||||||||||||
| **policy_kwargs: Any, | ||||||||||||||||||||||||||||||
| # TODO: add epsilon-noisy exploration | ||||||||||||||||||||||||||||||
| ) -> IsotropicGaussian: | ||||||||||||||||||||||||||||||
| """Transform the output of the module into a IsotropicGaussian distribution, | ||||||||||||||||||||||||||||||
| which is the distribution of the next states under the pinned Brownian motion | ||||||||||||||||||||||||||||||
|
|
@@ -1339,24 +1352,75 @@ def to_probability_distribution( | |||||||||||||||||||||||||||||
| states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1). | ||||||||||||||||||||||||||||||
| module_output: The output of the module (actions), as a tensor of shape | ||||||||||||||||||||||||||||||
| (*batch_shape, s_dim). | ||||||||||||||||||||||||||||||
| **policy_kwargs: Keyword arguments to modify the distribution. | ||||||||||||||||||||||||||||||
| **policy_kwargs: Keyword arguments to modify the distribution. Supported | ||||||||||||||||||||||||||||||
| keys: | ||||||||||||||||||||||||||||||
| - exploration_std: Optional callable or float controlling extra | ||||||||||||||||||||||||||||||
| exploration noise on top of the base diffusion std. The callable | ||||||||||||||||||||||||||||||
| should accept an integer step index and return a non-negative | ||||||||||||||||||||||||||||||
| standard deviation in state space. When provided, the extra noise | ||||||||||||||||||||||||||||||
| is combined in variance-space (logaddexp) with the base diffusion | ||||||||||||||||||||||||||||||
| variance; non-positive exploration is ignored. | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||
| A IsotropicGaussian distribution (distribution of the next states) | ||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||
| assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1" | ||||||||||||||||||||||||||||||
| s_curr = states.tensor[:, :-1] | ||||||||||||||||||||||||||||||
| # s_curr = states.tensor[:, :-1] | ||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||
| # s_curr = states.tensor[:, :-1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should revert change - am terminating one step too early.
Copilot
AI
Dec 18, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This TODO comment suggests uncertainty about the correctness of the exit condition change. TODOs requesting review in production code should be resolved before merging. Either verify the correctness and remove the TODO, or if there's genuine uncertainty, add a test to validate the behavior.
| eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS | |
| is_final_step = (t_curr + self.dt) >= (1.0 - eps) | |
| # TODO: The old code followed this convention (below). I believe the change | |
| # is slightly more correct, but I'd like to check this during review. | |
| # (1.0 - t_curr) < self.dt * 1e-2 # Triggers when t_curr ≈ 1.0 | |
| # Note: this replaces an older heuristic `(1.0 - t_curr) < self.dt * 1e-2`, | |
| # using the shared `_DIFFUSION_TERMINAL_TIME_EPS` tolerance for consistency. | |
| eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS | |
| is_final_step = (t_curr + self.dt) >= (1.0 - eps) |
Copilot
AI
Dec 18, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Taking the log of fwd_std can fail if fwd_std contains zeros or negative values. When n_variance_outputs is 0, fwd_std could potentially be very small or zero. Additionally, when exploration_std_t is 0, the clamp to 1e-12 on line 1415 only protects the extra_log_var computation but not the base_log_var. Consider adding a clamp_min to fwd_std before taking its log, or handle the exploration_std_t == 0 case separately to avoid unnecessary log operations.
| # Combine base diffusion variance σ_base^2 with exploration variance σ_expl^2: | |
| # σ_combined = sqrt(σ_base^2 + σ_expl^2). torch.compile friendly. | |
| base_log_var = 2 * fwd_std.log() # log(σ_base^2) | |
| # If there is no positive exploration noise, keep the base diffusion std. | |
| # This avoids unnecessary log operations and potential log(0) issues. | |
| if exploration_std_t.eq(0).all(): | |
| return IsotropicGaussian(fwd_mean, fwd_std) | |
| # Combine base diffusion variance σ_base^2 with exploration variance σ_expl^2: | |
| # σ_combined = sqrt(σ_base^2 + σ_expl^2). torch.compile friendly. | |
| # Clamp fwd_std to a small positive value before taking the log to avoid | |
| # numerical issues when fwd_std is extremely small or zero. | |
| safe_fwd_std = fwd_std.clamp_min(1e-12) | |
| base_log_var = 2 * safe_fwd_std.log() # log(σ_base^2) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,7 +11,11 @@ | |
| from gfn.estimators import Estimator | ||
| from gfn.samplers import Sampler | ||
| from gfn.states import States | ||
| from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs | ||
| from gfn.utils.prob_calculations import ( | ||
| get_trajectory_pbs, | ||
| get_trajectory_pfs, | ||
| get_trajectory_pfs_and_pbs, | ||
| ) | ||
|
|
||
| TrainingSampleType = TypeVar("TrainingSampleType", bound=Container) | ||
|
|
||
|
|
@@ -343,6 +347,32 @@ def get_pfs_and_pbs( | |
| recalculate_all_logprobs, | ||
| ) | ||
|
|
||
| def trajectory_log_probs_forward( | ||
| self, | ||
| trajectories: Trajectories, | ||
| fill_value: float = 0.0, | ||
| recalculate_all_logprobs: bool = True, | ||
| ) -> torch.Tensor: | ||
| """Evaluates forward logprobs only for each trajectory in the batch.""" | ||
| return get_trajectory_pfs( | ||
| self.pf, | ||
| trajectories, | ||
| fill_value=fill_value, | ||
| recalculate_all_logprobs=recalculate_all_logprobs, | ||
| ) | ||
|
|
||
| def trajectory_log_probs_backward( | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Leaving these here because they might come in handy, but I don't think they're actually needed right now in this implementation. |
||
| self, | ||
| trajectories: Trajectories, | ||
| fill_value: float = 0.0, | ||
| ) -> torch.Tensor: | ||
| """Evaluates backward logprobs only for each trajectory in the batch.""" | ||
| return get_trajectory_pbs( | ||
| self.pb, | ||
| trajectories, | ||
| fill_value=fill_value, | ||
| ) | ||
|
|
||
| def get_scores( | ||
| self, | ||
| trajectories: Trajectories, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The TODO comment suggests this constant should be handled centrally across multiple files (diffusion_sampling.py, estimators.py, and mle.py). Having the same magic value duplicated in three places is a maintenance risk. Consider creating a shared constants module or config file for cross-module constants like this.