4 changes: 2 additions & 2 deletions docs/source/guides/losses.md
@@ -5,7 +5,7 @@ GFlowNets can be trained with different losses, each of which requires a differe
Currently, the implemented losses are:

- Flow Matching: This is the original loss, and is mostly made available for completeness. It is slow to compute and hard to optimize, so it is generally not recommended for any significantly hard-to-learn problem.
- Detailed Balance (and it's modified variant).
- Detailed Balance (and its modified variant). In forward-looking mode, rewards must be defined on edges; the current implementation treats the edge reward as the difference between the successor and current state rewards, so only enable this when that matches your environment.
- Trajectory Balance
- Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined [here](https://www.semanticscholar.org/reader/f2c32fe3f7f3e2e9d36d833e32ec55fc93f900f5). Other strategies exist and are implemented [here](https://github.com/gfnorg/torchgfn/tree/master/src/gfn/losses/sub_trajectory_balance.py).
- Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined [here](https://www.semanticscholar.org/reader/f2c32fe3f7f3e2e9d36d833e32ec55fc93f900f5). Other strategies exist and are implemented [here](https://github.com/gfnorg/torchgfn/tree/master/src/gfn/losses/sub_trajectory_balance.py). When using geometric weighting, the 'mean' reduction is not supported; requests for a mean reduction are coerced to a sum (a warning is emitted when debug is enabled). A sketch of this weighting is given after this list.
- Log Partition Variance loss. Introduced [here](https://arxiv.org/abs/2302.05446).
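
The geometric weighting mentioned in the Sub-Trajectory Balance item can be summarized with a small sketch. This is an illustration of the weighting scheme described above, not the library's exact implementation; the function name `geometric_subtrajectory_weights` and the parameter `lamda` are hypothetical.

```python
import torch

def geometric_subtrajectory_weights(T: int, lamda: float = 0.9) -> torch.Tensor:
    """Illustrative only: weights for all sub-trajectories of a length-T trajectory.

    A sub-trajectory running from step i to step j (0 <= i < j <= T) has length
    j - i and receives weight lamda ** (j - i), normalized so that the weights of
    all sub-trajectories within the trajectory sum to one.
    """
    lengths = torch.tensor(
        [j - i for i in range(T) for j in range(i + 1, T + 1)], dtype=torch.float
    )
    weights = lamda ** lengths
    return weights / weights.sum()

# Example: a length-4 trajectory has 10 sub-trajectories; with lamda < 1 the
# longer sub-trajectories receive geometrically smaller weight.
print(geometric_subtrajectory_weights(4).sum())  # tensor(1.)
```
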
54 changes: 40 additions & 14 deletions src/gfn/gflownet/base.py
@@ -1,4 +1,3 @@
import math
import warnings
from abc import ABC, abstractmethod
from typing import Any, Generic, Tuple, TypeVar
@@ -48,6 +47,16 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):

log_reward_clip_min = float("-inf") # Default off.

def __init__(self, debug: bool = False) -> None:
"""Initialize shared GFlowNet state.

Args:
debug: If True, keep runtime safety checks and warnings active. Set False
in compiled hot paths to avoid graph breaks; use True in tests/debugging.
"""
super().__init__()
self.debug = debug

@abstractmethod
def sample_trajectories(
self,
@@ -148,13 +157,17 @@ def loss_from_trajectories(

def assert_finite_gradients(self):
"""Asserts that the gradients are finite."""
if not self.debug:
return
for p in self.parameters():
if p.grad is not None:
if not torch.isfinite(p.grad).all():
raise RuntimeError("GFlowNet has non-finite gradients")

def assert_finite_parameters(self):
"""Asserts that the parameters are finite."""
if not self.debug:
return
for p in self.parameters():
if not torch.isfinite(p).all():
raise RuntimeError("GFlowNet has non-finite parameters")
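
As a usage note, the pattern these checks follow can be sketched in isolation. The class below is illustrative only and is not part of the library; it mirrors how the `debug` flag gates data-dependent checks so that a hot path constructed with `debug=False` (the default) keeps `assert_finite_gradients` and `assert_finite_parameters` as cheap no-ops and avoids graph breaks under torch.compile.

```python
import torch
import torch.nn as nn

class DebugGatedModule(nn.Module):
    """Illustrative stand-in for the debug-gated checks added in this PR."""

    def __init__(self, debug: bool = False) -> None:
        super().__init__()
        self.debug = debug
        self.w = nn.Parameter(torch.randn(4))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.w

    def assert_finite_parameters(self) -> None:
        if not self.debug:
            return  # no-op in compiled hot paths: no data-dependent branching
        for p in self.parameters():
            if not torch.isfinite(p).all():
                raise RuntimeError("non-finite parameters")

module = DebugGatedModule(debug=True)  # tests / debugging: keep checks active
module.assert_finite_parameters()      # raises if any parameter is NaN/inf
```
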
@@ -177,6 +190,7 @@ def __init__(
pb: Estimator | None,
constant_pb: bool = False,
log_reward_clip_min: float = float("-inf"),
debug: bool = False,
) -> None:
"""Initializes a PFBasedGFlowNet instance.

@@ -189,9 +203,10 @@ def __init__(
explicitly by user to ensure that pb is an Estimator except under this
special case.
log_reward_clip_min: If finite, clips log rewards to this value.
debug: If True, keep runtime safety checks active; disable in compiled runs.

"""
super().__init__()
super().__init__(debug=debug)
# Technical note: pb may be constant for a variety of edge cases, for example,
# if all terminal states can be reached with exactly the same number of
# trajectories, and we assume a uniform backward policy, then we can omit the pb
@@ -365,23 +380,34 @@ def get_scores(
)

assert log_pf_trajectories is not None
total_log_pf_trajectories = log_pf_trajectories.sum(dim=0)
total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)
total_log_pf_trajectories = log_pf_trajectories.sum(dim=0) # [N]
total_log_pb_trajectories = log_pb_trajectories.sum(dim=0) # [N]

log_rewards = trajectories.log_rewards
assert log_rewards is not None

if math.isfinite(self.log_reward_clip_min):
# Fast path: skip clamp when log_reward_clip_min is -inf to avoid extra work.
# TODO: Do we need log reward clamping at all?
if self.log_reward_clip_min != float("-inf"):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)

if torch.any(torch.isinf(total_log_pf_trajectories)):
raise ValueError("Infinite pf logprobs found")
if torch.any(torch.isinf(total_log_pb_trajectories)):
raise ValueError("Infinite pb logprobs found")

assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,)
assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,)
return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards
# Keep runtime safety checks under `debug` to avoid graph breaks in torch.compile.
if self.debug:
if torch.any(torch.isinf(total_log_pf_trajectories)):
raise ValueError("Infinite pf logprobs found")
if torch.any(torch.isinf(total_log_pb_trajectories)):
raise ValueError("Infinite pb logprobs found")
assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,)
assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,)

# Fused (pf - pb) then subtract rewards; keep it branch-free/out-of-place
# to stay friendly to torch.compile graphs.
scores = torch.sub(
total_log_pf_trajectories, total_log_pb_trajectories, alpha=1.0
)
# Subtract rewards in a separate op to avoid in-place mutations (graph-stable)
# while still keeping only one extra temporary.
scores = scores - log_rewards
return scores
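
For context, the tensor returned here is the per-trajectory residual (sum log P_F) - (sum log P_B) - log R(x), from which the trajectory-level objectives can be formed. The sketch below uses a placeholder `scores` tensor rather than calling `get_scores` directly, since the full call signature is not visible in this diff; the variable names are illustrative.

```python
import torch

# Illustrative only: `scores` stands in for the tensor returned by get_scores,
# where scores[i] = (sum log P_F) - (sum log P_B) - log R(x) for trajectory i.
scores = torch.randn(16)                     # placeholder batch of 16 trajectory scores
log_Z = torch.zeros((), requires_grad=True)  # learned log-partition estimate (TB only)

tb_loss = (scores + log_Z).pow(2).mean()     # Trajectory Balance objective
zvar_loss = scores.var()                     # Log Partition Variance: variance of the
                                             # per-trajectory log Z estimates (-scores)
```
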

def to_training_samples(self, trajectories: Trajectories) -> Trajectories:
"""Returns the input trajectories as training samples.