Changes from all commits (30 commits)
ee17515: rtb prototype (josephdviviano, Dec 15, 2025)
85190b0: sync of RTB prototype (josephdviviano, Dec 16, 2025)
6ecd0ca: ignore outputs (josephdviviano, Dec 16, 2025)
08258ad: rtb finetune first pass (josephdviviano, Dec 16, 2025)
6ea54cd: loss bugfixes (josephdviviano, Dec 16, 2025)
a0666da: prior learning works, still working on finetune step (josephdviviano, Dec 16, 2025)
f771c2c: RTB is working - next step is to factorize (josephdviviano, Dec 16, 2025)
51ae63d: refactored MLE pipeline (josephdviviano, Dec 17, 2025)
4aebf20: added the MLE trainer (josephdviviano, Dec 17, 2025)
f685b86: cleaned up training script (josephdviviano, Dec 17, 2025)
e647877: shrunk script for clarity (josephdviviano, Dec 17, 2025)
06ccc10: no change (josephdviviano, Dec 17, 2025)
bfbbc22: fixed backward bug (josephdviviano, Dec 18, 2025)
8f8cadb: Merge branch 'master' into relative_trajectory_balance (josephdviviano, Dec 18, 2025)
bb2cb45: changed back to old line lengths (josephdviviano, Dec 18, 2025)
69efb35: Merge branch 'relative_trajectory_balance' of github.com:GFNOrg/torch… (josephdviviano, Dec 18, 2025)
dd2c4e4: Initial plan (Copilot, Dec 18, 2025)
c8eb351: Configure codecov to not block CI/merge requests (Copilot, Dec 18, 2025)
63a4bd1: Update src/gfn/estimators.py (josephdviviano, Dec 18, 2025)
46807ed: Update tutorials/examples/train_diffusion_rtb.py (josephdviviano, Dec 18, 2025)
b40655b: Update tutorials/examples/train_diffusion_rtb.py (josephdviviano, Dec 18, 2025)
74b8a60: Update tutorials/examples/train_diffusion_rtb.py (josephdviviano, Dec 18, 2025)
7deba0e: Update src/gfn/utils/modules.py (josephdviviano, Dec 18, 2025)
a2f5a6c: Update src/gfn/gflownet/trajectory_balance.py (josephdviviano, Dec 18, 2025)
ee22f3c: Update src/gfn/gym/diffusion_sampling.py (josephdviviano, Dec 18, 2025)
f8d5c5b: Merge pull request #458 from GFNOrg/copilot/sub-pr-457 (josephdviviano, Dec 18, 2025)
ad457f6: added diffusion tests (josephdviviano, Dec 18, 2025)
f439a78: Initial plan (Copilot, Dec 18, 2025)
b141f3f: Refactor: Define OUTPUT_DIR constant for visualization output directory (Copilot, Dec 18, 2025)
97a2453: Merge pull request #459 from GFNOrg/copilot/sub-pr-457-again (josephdviviano, Dec 18, 2025)
15 changes: 15 additions & 0 deletions .codecov.yml
@@ -0,0 +1,15 @@
coverage:
  status:
    project:
      default:
        # Set to informational only - will not block PRs
        informational: true
    patch:
      default:
        # Set to informational only - will not block PRs
        informational: true

comment:
  # Still show coverage comments on PRs
  layout: "diff, flags, files"
  behavior: default
2 changes: 1 addition & 1 deletion .flake8
@@ -2,4 +2,4 @@
ignore = E203, E266, E501, W503, F403, F401, F821
max-line-length = 89
max-complexity = 18
select = B,C,E,F,W,T4,B9
select = B,C,E,F,W,T4,B9
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -44,5 +44,5 @@ jobs:
        uses: codecov/codecov-action@v5
        with:
          files: coverage.xml
          fail_ci_if_error: true
          fail_ci_if_error: false
          token: ${{ secrets.CODECOV_TOKEN }}
6 changes: 4 additions & 2 deletions src/gfn/env.py
@@ -322,15 +322,17 @@ def _step(self, states: States, actions: Actions) -> States:
# We only step on states that are not sink states.
# Note that exit actions directly set the states to the sink state, so they
# are not included in the valid_states_idx.
new_valid_states_idx = valid_states_idx & ~actions.is_exit
new_valid_states_idx = valid_states_idx & ~actions.is_exit # boolean mask.

# IMPORTANT: .clone() is used to ensure that the new states are a
# distinct object from the old states. This is important for the sampler to
# work correctly when building the trajectories. If you want to override this
# method in your custom environment, you must ensure that the `new_states`
# returned is a distinct object from the submitted states.
not_done_states = states[new_valid_states_idx].clone()
not_done_actions = actions[new_valid_states_idx]
not_done_actions = actions[
new_valid_states_idx
] # NOTE: boolean indexing creates a copy!

not_done_states = self.step(not_done_states, not_done_actions)
assert isinstance(
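The NOTE added in this hunk ("boolean indexing creates a copy!") is easy to verify in isolation. A minimal standalone check, not part of the diff: in PyTorch, indexing with a boolean mask allocates a new tensor, while basic slicing returns a view of the same storage.

import torch

x = torch.arange(6)
mask = torch.tensor([True, False, True, False, True, False])

masked = x[mask]  # boolean (advanced) indexing: copies the selected elements
sliced = x[0:3]   # basic slicing: returns a view into x's storage

masked[0] = 100   # does not touch x
sliced[0] = -1    # mutates x through the view

print(x)       # tensor([-1,  1,  2,  3,  4,  5])
print(masked)  # tensor([100,   2,   4])

This is presumably why `not_done_actions` needs no extra `.clone()`, while the `States` slice above still clones explicitly.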
130 changes: 109 additions & 21 deletions src/gfn/estimators.py
@@ -1,3 +1,4 @@
import math
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Protocol, cast, runtime_checkable
@@ -26,6 +27,12 @@
"prod": torch.prod,
}

# Relative tolerance for detecting terminal time in diffusion estimators.
# Must match TERMINAL_TIME_EPS in gfn.gym.diffusion_sampling to ensure consistent
# exit action detection between the estimator and environment. TODO: we should handle this
# centrally somewhere.
_DIFFUSION_TERMINAL_TIME_EPS = 1e-2
Comment on lines +30 to +34
Copilot AI (Dec 18, 2025): The TODO comment suggests this constant should be handled centrally across multiple files (diffusion_sampling.py, estimators.py, and mle.py). Having the same magic value duplicated in three places is a maintenance risk. Consider creating a shared constants module or config file for cross-module constants like this.
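One way to act on this suggestion, as a minimal sketch (the module path and name below are assumptions, not part of this PR):

# src/gfn/constants.py (hypothetical shared module)
# Single source of truth for the terminal-time tolerance used by the
# diffusion environment, the diffusion estimators, and the MLE trainer.
TERMINAL_TIME_EPS: float = 1e-2

diffusion_sampling.py and estimators.py would then both do `from gfn.constants import TERMINAL_TIME_EPS`, removing the duplicated magic value.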


class RolloutContext:
"""Structured per‑rollout state owned by estimators.
@@ -1290,6 +1297,7 @@ def __init__(
pf_module: nn.Module,
sigma: float,
num_discretization_steps: int,
n_variance_outputs: int = 0,
):
"""Initialize the PinnedBrownianMotionForward.

@@ -1305,6 +1313,12 @@
self.sigma = sigma
self.num_discretization_steps = num_discretization_steps
self.dt = 1.0 / self.num_discretization_steps
self.n_variance_outputs = n_variance_outputs

@property
def expected_output_dim(self) -> int:
# Drift (s_dim) plus optional variance outputs.
return self.s_dim + self.n_variance_outputs

def forward(self, input: States) -> torch.Tensor:
"""Forward pass of the module.
@@ -1329,7 +1343,6 @@ def to_probability_distribution(
states: States,
module_output: torch.Tensor,
**policy_kwargs: Any,
# TODO: add epsilon-noisy exploration
) -> IsotropicGaussian:
"""Transform the output of the module into a IsotropicGaussian distribution,
which is the distribution of the next states under the pinned Brownian motion
@@ -1339,24 +1352,75 @@
states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1).
module_output: The output of the module (actions), as a tensor of shape
(*batch_shape, s_dim).
**policy_kwargs: Keyword arguments to modify the distribution.
**policy_kwargs: Keyword arguments to modify the distribution. Supported
keys:
- exploration_std: Optional callable or float controlling extra
exploration noise on top of the base diffusion std. The callable
should accept an integer step index and return a non-negative
standard deviation in state space. When provided, the extra noise
is combined in variance-space (logaddexp) with the base diffusion
variance; non-positive exploration is ignored.

Returns:
An IsotropicGaussian distribution (the distribution of the next states)
"""
assert len(states.batch_shape) == 1, "States must have a batch_shape of length 1"
s_curr = states.tensor[:, :-1]
# s_curr = states.tensor[:, :-1]
Copilot AI (Dec 18, 2025): The commented-out line appears to be dead code that should be removed. If it's intended for reference, consider moving it to a comment explaining why the change was made rather than leaving commented code. Suggested change: delete the line `# s_curr = states.tensor[:, :-1]`.
t_curr = states.tensor[:, [-1]]

# Check if the NEXT step would reach terminal time, not if we're already there.
# This matches the exit condition in DiffusionSampling.step() and ensures the
# sampled action is marked as an exit action (-inf) so trajectory masks align
# correctly in get_trajectory_pbs.
eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
josephdviviano (Collaborator, author): Should revert change - am terminating one step too early.

is_final_step = (t_curr + self.dt) >= (1.0 - eps)
# TODO: The old code followed this convention (below). I believe the change
# is slightly more correct, but I'd like to check this during review.
# (1.0 - t_curr) < self.dt * 1e-2 # Triggers when t_curr ≈ 1.0
Comment on lines +1375 to +1379
Copilot AI (Dec 18, 2025): This TODO comment suggests uncertainty about the correctness of the exit condition change. TODOs requesting review in production code should be resolved before merging. Either verify the correctness and remove the TODO, or if there's genuine uncertainty, add a test to validate the behavior.
Suggested change:
# Note: this replaces an older heuristic `(1.0 - t_curr) < self.dt * 1e-2`,
# using the shared `_DIFFUSION_TERMINAL_TIME_EPS` tolerance for consistency.
eps = self.dt * _DIFFUSION_TERMINAL_TIME_EPS
is_final_step = (t_curr + self.dt) >= (1.0 - eps)
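To make the disagreement concrete, a standalone check (assuming num_discretization_steps = 4; not part of the PR) shows the new condition firing one step earlier than the old heuristic, consistent with the author's note above:

import torch

dt = 1.0 / 4      # num_discretization_steps = 4
eps = dt * 1e-2   # _DIFFUSION_TERMINAL_TIME_EPS = 1e-2
t_curr = torch.tensor([0.25, 0.50, 0.75, 1.00])

old = (1.0 - t_curr) < dt * 1e-2     # fires only once t_curr is ~1.0
new = (t_curr + dt) >= (1.0 - eps)   # also fires at t_curr = 0.75

print(old)  # tensor([False, False, False,  True])
print(new)  # tensor([False, False,  True,  True])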

module_output = torch.where(
(1.0 - t_curr) < self.dt * 1e-2, # sf case; when t_curr is 1.0
torch.full_like(s_curr, -float("inf")), # This is the exit action
is_final_step,
torch.full_like(module_output, -float("inf")), # This is the exit action
module_output,
)

fwd_mean = self.dt * module_output
fwd_std = torch.tensor(self.sigma * self.dt**0.5, device=fwd_mean.device)
fwd_std = fwd_std.repeat(fwd_mean.shape[0], 1)
drift = module_output[..., : self.s_dim]
if self.n_variance_outputs > 0:
var_part = module_output[..., self.s_dim :]
# Reduce extra variance dims to a single scalar (isotropic for now).
log_std = var_part.mean(dim=-1, keepdim=True)
fwd_std = torch.exp(log_std) * math.sqrt(self.dt)
else:
fwd_std = torch.tensor(self.sigma * self.dt**0.5, device=drift.device)
fwd_std = fwd_std.repeat(drift.shape[0], 1)

# Match reference behavior: scale diffusion noise (not drift) by t_scale if present.
t_scale_factor = getattr(self.module, "t_scale", 1.0)
if t_scale_factor != 1.0:
fwd_std = fwd_std * math.sqrt(t_scale_factor)

fwd_mean = self.dt * drift

# Optional exploration noise: combine variances (quadrature/logaddexp).
exploration_std = policy_kwargs.pop("exploration_std", None)
exploration_std_t = torch.as_tensor(
exploration_std if exploration_std is not None else 0.0,
device=fwd_std.device,
dtype=fwd_std.dtype,
).clamp(min=0.0)

# Combine base diffusion variance σ_base^2 with exploration variance σ_expl^2:
# σ_combined = sqrt(σ_base^2 + σ_expl^2). torch.compile friendly.
base_log_var = 2 * fwd_std.log() # log(σ_base^2)
Comment on lines +1412 to +1414
Copilot AI (Dec 18, 2025): Taking the log of fwd_std can fail if fwd_std contains zeros or negative values. When n_variance_outputs is 0, fwd_std could potentially be very small or zero. Additionally, when exploration_std_t is 0, the clamp to 1e-12 on line 1415 only protects the extra_log_var computation but not the base_log_var. Consider adding a clamp_min to fwd_std before taking its log, or handle the exploration_std_t == 0 case separately to avoid unnecessary log operations.
Suggested change:
# If there is no positive exploration noise, keep the base diffusion std.
# This avoids unnecessary log operations and potential log(0) issues.
if exploration_std_t.eq(0).all():
    return IsotropicGaussian(fwd_mean, fwd_std)
# Combine base diffusion variance σ_base^2 with exploration variance σ_expl^2:
# σ_combined = sqrt(σ_base^2 + σ_expl^2). torch.compile friendly.
# Clamp fwd_std to a small positive value before taking the log to avoid
# numerical issues when fwd_std is extremely small or zero.
safe_fwd_std = fwd_std.clamp_min(1e-12)
base_log_var = 2 * safe_fwd_std.log()  # log(σ_base^2)
extra_log_var = 2 * exploration_std_t.clamp(min=1e-12).log() # log(σ_expl^2)
extra_log_var_tensor = extra_log_var.expand_as(base_log_var)
combined_log_var = torch.logaddexp(base_log_var, extra_log_var_tensor)
fwd_std = torch.where(
exploration_std_t > 0,
torch.exp(0.5 * combined_log_var),
fwd_std,
)

return IsotropicGaussian(fwd_mean, fwd_std)
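The logaddexp trick above is just quadrature done in log-variance space. A standalone check with assumed values (not part of the PR) confirming σ_combined = sqrt(σ_base² + σ_expl²):

import torch

base_std = torch.tensor([0.1, 0.5])
expl_std = torch.tensor(0.2)

# Direct quadrature.
direct = (base_std**2 + expl_std**2).sqrt()

# Log-variance space, as in to_probability_distribution above.
combined_log_var = torch.logaddexp(2 * base_std.log(), 2 * expl_std.log())
via_logaddexp = torch.exp(0.5 * combined_log_var)

assert torch.allclose(direct, via_logaddexp)  # both tensor([0.2236, 0.5385])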


@@ -1367,30 +1431,34 @@ def __init__(
pb_module: nn.Module,
sigma: float,
num_discretization_steps: int,
n_variance_outputs: int = 0,
pb_scale_range: float = 0.1,
):
"""Initialize the PinnedBrownianMotionForward.
"""Initialize the PinnedBrownianMotionBackward.

Args:
s_dim: The dimension of the states.
pb_module: The neural network module to use for the backward policy.
sigma: The diffusion coefficient parameter for the pinned Brownian motion.
num_discretization_steps: The number of discretization steps.
n_variance_outputs: Number of variance outputs (0=fixed, 1=learned corr).
pb_scale_range: Scaling applied to learned corrections (tanh-bounded).
"""
super().__init__(s_dim=s_dim, module=pb_module, is_backward=True)

# Pinned Brownian Motion related
self.sigma = sigma
self.dt = 1.0 / num_discretization_steps
self.n_variance_outputs = n_variance_outputs
self.pb_scale_range = pb_scale_range

def forward(self, input: States) -> torch.Tensor:
"""Forward pass of the module.

Args:
input: The input to the module as states.
@property
def expected_output_dim(self) -> int:
# Drift correction (s_dim) plus optional variance correction outputs.
return self.s_dim + self.n_variance_outputs

Returns:
The output of the module, as a tensor of shape (*batch_shape, output_dim).
"""
def forward(self, input: States) -> torch.Tensor:
"""Forward pass of the module."""
out = self.module(self.preprocessor(input))

if self.expected_output_dim is not None:
@@ -1411,6 +1479,7 @@ def to_probability_distribution(
which is the distribution of the previous states under the pinned Brownian motion
process, possibly controlled by the output of the backward module. If the module
is a fixed backward module, the `module_output` is a zero vector (no control).
Includes optional learned corrections.

Args:
states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1).
@@ -1426,14 +1495,33 @@
t_curr = states.tensor[:, [-1]] # shape: (*batch_shape, 1)

is_s0 = (t_curr - self.dt) < self.dt * 1e-2 # s0 case; when t_curr - dt is 0.0
bwd_mean = torch.where(
# Analytic Brownian bridge base
# Brownian bridge mean toward 0 at t=0:
# E[s_{t-dt} | s_t] = s_t * (1 - dt / t) and collapses to 0 at the start.
# Here, we calculate the *action* which moves the state in expectation toward 0
# at t=0, so we scale s_curr by our distance to t=0.
base_mean = torch.where(
is_s0,
s_curr,
s_curr * self.dt / t_curr,
torch.zeros_like(s_curr),
s_curr
* self.dt
/ t_curr, # s_curr (batch, s_dim), t_curr (batch, 1), dt is scalar.
)
bwd_std = torch.where(
base_std = torch.where(
is_s0,
torch.zeros_like(t_curr),
self.sigma * (self.dt * (t_curr - self.dt) / t_curr).sqrt(),
)

# Optional learned corrections (tanh-bounded); when n_variance_outputs==0, only mean corr.
mean_corr = module_output[..., : self.s_dim] * self.pb_scale_range
if self.n_variance_outputs > 0 and module_output.shape[-1] >= self.s_dim + 1:
log_std_corr = module_output[..., [-1]] * self.pb_scale_range
corr_std = torch.exp(log_std_corr)
else:
corr_std = torch.zeros_like(base_std)

bwd_mean = base_mean + mean_corr
bwd_std = (base_std**2 + corr_std**2).sqrt()

return IsotropicGaussian(bwd_mean, bwd_std)
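The analytic base above is the standard Brownian bridge posterior: conditioning pinned Brownian motion s_t = s_{t-dt} + N(0, σ²·dt), with s_{t-dt} ~ N(0, σ²·(t-dt)), gives E[s_{t-dt} | s_t] = s_t·(1 - dt/t) and Var = σ²·dt·(t-dt)/t. A standalone Monte Carlo check with assumed values (not part of the PR):

import torch

torch.manual_seed(0)
sigma, t, dt = 1.0, 0.75, 0.25
n = 2_000_000

s_prev = torch.randn(n) * sigma * (t - dt) ** 0.5  # s_{t-dt}
s_t = s_prev + torch.randn(n) * sigma * dt**0.5    # one Euler step forward

# Condition on s_t near a fixed value and compare with the closed form.
mask = (s_t - 0.8).abs() < 1e-2
emp_mean, emp_std = s_prev[mask].mean(), s_prev[mask].std()

ana_mean = 0.8 * (1 - dt / t)                # i.e. s_t minus the action s_t*dt/t
ana_std = sigma * (dt * (t - dt) / t) ** 0.5
print(emp_mean.item(), ana_mean)  # ~0.533 for both
print(emp_std.item(), ana_std)    # ~0.408 for both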
7 changes: 6 additions & 1 deletion src/gfn/gflownet/__init__.py
@@ -2,7 +2,11 @@
from .detailed_balance import DBGFlowNet, ModifiedDBGFlowNet
from .flow_matching import FMGFlowNet
from .sub_trajectory_balance import SubTBGFlowNet
from .trajectory_balance import LogPartitionVarianceGFlowNet, TBGFlowNet
from .trajectory_balance import (
LogPartitionVarianceGFlowNet,
RelativeTrajectoryBalanceGFlowNet,
TBGFlowNet,
)

__all__ = [
"GFlowNet",
@@ -13,5 +17,6 @@
"FMGFlowNet",
"SubTBGFlowNet",
"LogPartitionVarianceGFlowNet",
"RelativeTrajectoryBalanceGFlowNet",
"TBGFlowNet",
]
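For orientation, RelativeTrajectoryBalanceGFlowNet implements relative trajectory balance (RTB; Venkatraman et al., 2024), which finetunes a posterior sampler against a frozen prior so that p_post(τ) ∝ p_prior(τ)·r(x). The class's constructor is not shown in this diff; the following is a hedged sketch of the objective only, not the library's API:

import torch

def rtb_loss(
    log_z: torch.Tensor,             # learned scalar log-partition estimate
    log_pf_posterior: torch.Tensor,  # per-trajectory posterior forward logprob
    log_pf_prior: torch.Tensor,      # per-trajectory (frozen) prior forward logprob
    log_reward: torch.Tensor,        # log r(x) at each terminal state
) -> torch.Tensor:
    residual = log_z + log_pf_posterior - log_pf_prior - log_reward
    return residual.pow(2).mean()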
32 changes: 31 additions & 1 deletion src/gfn/gflownet/base.py
@@ -11,7 +11,11 @@
from gfn.estimators import Estimator
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.prob_calculations import get_trajectory_pfs_and_pbs
from gfn.utils.prob_calculations import (
get_trajectory_pbs,
get_trajectory_pfs,
get_trajectory_pfs_and_pbs,
)

TrainingSampleType = TypeVar("TrainingSampleType", bound=Container)

@@ -343,6 +347,32 @@ def get_pfs_and_pbs(
recalculate_all_logprobs,
)

def trajectory_log_probs_forward(
self,
trajectories: Trajectories,
fill_value: float = 0.0,
recalculate_all_logprobs: bool = True,
) -> torch.Tensor:
"""Evaluates forward logprobs only for each trajectory in the batch."""
return get_trajectory_pfs(
self.pf,
trajectories,
fill_value=fill_value,
recalculate_all_logprobs=recalculate_all_logprobs,
)

def trajectory_log_probs_backward(
josephdviviano (Collaborator, author): leaving these here because they might come in handy, but I don't think they're actually needed right now in this implementation.

self,
trajectories: Trajectories,
fill_value: float = 0.0,
) -> torch.Tensor:
"""Evaluates backward logprobs only for each trajectory in the batch."""
return get_trajectory_pbs(
self.pb,
trajectories,
fill_value=fill_value,
)

def get_scores(
self,
trajectories: Trajectories,
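A sketch of how the new helpers could slot into an RTB-style update, using the rtb_loss sketch above (sampler, env, and the per-step shape returned by the helpers are assumptions, not code from this PR):

# Hypothetical training step; gflownet_post and gflownet_prior are two
# GFlowNet instances over the same env, with the prior's weights frozen.
trajectories = sampler.sample_trajectories(env, n=batch_size)

# Assuming the helpers return per-step logprobs, sum over the step dimension.
log_pf_post = gflownet_post.trajectory_log_probs_forward(trajectories).sum(0)
log_pf_prior = gflownet_prior.trajectory_log_probs_forward(
    trajectories, recalculate_all_logprobs=True
).sum(0)

loss = rtb_loss(log_z, log_pf_post, log_pf_prior, trajectories.log_rewards)
loss.backward()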