From b3715853dfe4b0107e612c4ea6dcf2c8edf98735 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 12 Dec 2025 09:10:10 -0500 Subject: [PATCH 01/16] debug threading through base class, gating of various warnings / asserts --- src/gfn/gflownet/base.py | 32 +++++-- src/gfn/gflownet/detailed_balance.py | 97 ++++++++++++++-------- src/gfn/gflownet/flow_matching.py | 14 ++-- src/gfn/gflownet/sub_trajectory_balance.py | 33 +++++--- src/gfn/gflownet/trajectory_balance.py | 18 ++-- 5 files changed, 134 insertions(+), 60 deletions(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 5542019b..57e4897a 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -48,6 +48,16 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): log_reward_clip_min = float("-inf") # Default off. + def __init__(self, debug: bool = False) -> None: + """Initialize shared GFlowNet state. + + Args: + debug: If True, keep runtime safety checks and warnings active. Set False + in compiled hot paths to avoid graph breaks; use True in tests/debugging. + """ + super().__init__() + self.debug = debug + @abstractmethod def sample_trajectories( self, @@ -148,6 +158,8 @@ def loss_from_trajectories( def assert_finite_gradients(self): """Asserts that the gradients are finite.""" + if not self.debug: + return for p in self.parameters(): if p.grad is not None: if not torch.isfinite(p.grad).all(): @@ -155,6 +167,8 @@ def assert_finite_gradients(self): def assert_finite_parameters(self): """Asserts that the parameters are finite.""" + if not self.debug: + return for p in self.parameters(): if not torch.isfinite(p).all(): raise RuntimeError("GFlowNet has non-finite parameters") @@ -177,6 +191,7 @@ def __init__( pb: Estimator | None, constant_pb: bool = False, log_reward_clip_min: float = float("-inf"), + debug: bool = False, ) -> None: """Initializes a PFBasedGFlowNet instance. @@ -189,9 +204,10 @@ def __init__( explicitly by user to ensure that pb is an Estimator except under this special case. log_reward_clip_min: If finite, clips log rewards to this value. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ - super().__init__() + super().__init__(debug=debug) # Technical note: pb may be constant for a variety of edge cases, for example, # if all terminal states can be reached with exactly the same number of # trajectories, and we assume a uniform backward policy, then we can omit the pb @@ -374,13 +390,15 @@ def get_scores( if math.isfinite(self.log_reward_clip_min): log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) - if torch.any(torch.isinf(total_log_pf_trajectories)): - raise ValueError("Infinite pf logprobs found") - if torch.any(torch.isinf(total_log_pb_trajectories)): - raise ValueError("Infinite pb logprobs found") + # Keep runtime safety checks under `debug` to avoid graph breaks in torch.compile. + if self.debug: + if torch.any(torch.isinf(total_log_pf_trajectories)): + raise ValueError("Infinite pf logprobs found") + if torch.any(torch.isinf(total_log_pb_trajectories)): + raise ValueError("Infinite pb logprobs found") + assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,) + assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,) - assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,) - assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,) return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards def to_training_samples(self, trajectories: Trajectories) -> Trajectories: diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index feff5462..79fbfc22 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -59,7 +59,10 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]): pb: The backward policy estimator. logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states. - forward_looking: Whether to use the forward-looking GFN loss. + forward_looking: Whether to use the forward-looking GFN loss. When True, + rewards must be defined over edges; this implementation treats the edge + reward as the difference between the successor and current state rewards, + so only valid if the environment follows that assumption. constant_pb: Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. log_reward_clip_min: If finite, clips log rewards to this value. @@ -73,6 +76,7 @@ def __init__( forward_looking: bool = False, constant_pb: bool = False, log_reward_clip_min: float = -float("inf"), + debug: bool = False, ) -> None: """Initializes a DBGFlowNet instance. @@ -82,16 +86,24 @@ def __init__( pb is therefore always 1. logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states. - forward_looking: Whether to use the forward-looking GFN loss. + forward_looking: Whether to use the forward-looking GFN loss. When True, + rewards should be defined over edges; this implementation treats the + edge reward as the difference between the successor and current state + rewards, so only valid if the environment follows that assumption. constant_pb: Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. log_reward_clip_min: If finite, clips log rewards to this value. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ super().__init__( - pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min + pf, + pb, + constant_pb=constant_pb, + log_reward_clip_min=log_reward_clip_min, + debug=debug, ) # Disallow recurrent PF for transition-based DB @@ -158,7 +170,10 @@ def get_pfs_and_pbs( ) def get_scores( - self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = True + self, + env: Env, + transitions: Transitions, + recalculate_all_logprobs: bool = True, ) -> torch.Tensor: r"""Calculates the scores for a batch of transitions. @@ -174,7 +189,8 @@ def get_scores( A tensor of shape (n_transitions,) representing the scores for each transition. """ - if transitions.is_backward: + # Guard bad inputs under debug to avoid graph breaks in torch.compile. + if self.debug and transitions.is_backward: raise ValueError("Backward transitions are not supported") states = transitions.states @@ -183,10 +199,11 @@ def get_scores( if len(states) == 0: return torch.tensor(0.0, device=transitions.device) - check_compatibility(states, actions, transitions) - assert ( - not transitions.states.is_sink_state.any() - ), "Transition from sink state is not allowed. This is a bug." + if self.debug: + check_compatibility(states, actions, transitions) + assert ( + not transitions.states.is_sink_state.any() + ), "Transition from sink state is not allowed. This is a bug." ### Compute log_pf and log_pb log_pf, log_pb = self.get_pfs_and_pbs(transitions, recalculate_all_logprobs) @@ -220,15 +237,17 @@ def get_scores( # Apply forward-looking if applicable if self.forward_looking: - import warnings - - warnings.warn( - "Rewards should be defined over edges in forward-looking settings. " - "The current implementation is a special case of this, where the edge " - "reward is defined as the difference between the reward of two states " - "that the edge connects. If your environment is not the case, " - "forward-looking may be inappropriate." - ) + # Keep explanatory warning only in debug to avoid compile-time graph breaks. + if self.debug: + import warnings + + warnings.warn( + "Rewards should be defined over edges in forward-looking settings. " + "The current implementation is a special case of this, where the edge " + "reward is defined as the difference between the reward of two states " + "that the edge connects. If your environment is not the case, " + "forward-looking may be inappropriate." + ) # Reward calculation can also be conditional. if transitions.conditions is not None: @@ -279,17 +298,23 @@ def loss( transitions: The Transitions object to compute the loss with. recalculate_all_logprobs: Whether to re-evaluate all logprobs. reduction: The reduction method to use ('mean', 'sum', or 'none'). + Run with self.debug=False for improved performance. Returns: The computed detailed balance loss as a tensor. The shape depends on the reduction method. """ - warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs) - scores = self.get_scores(env, transitions, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs) + scores = self.get_scores( + env, + transitions, + recalculate_all_logprobs=recalculate_all_logprobs, + ) scores = scores**2 loss = loss_reduce(scores, reduction) - if torch.isnan(loss).any(): + if self.debug and torch.isnan(loss).any(): raise ValueError("loss is nan") return loss @@ -327,6 +352,7 @@ def __init__( pf: Estimator, pb: Estimator | None, constant_pb: bool = False, + debug: bool = False, ) -> None: """Initializes a ModifiedDBGFlowNet instance. @@ -334,12 +360,15 @@ def __init__( pf: Forward policy estimator. pb: Backward policy estimator or None. constant_pb: See base class. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__(pf, pb, constant_pb=constant_pb, debug=debug) def get_scores( - self, transitions: Transitions, recalculate_all_logprobs: bool = True + self, + transitions: Transitions, + recalculate_all_logprobs: bool = True, ) -> torch.Tensor: """Calculates DAG-GFN-style modified detailed balance scores. @@ -360,7 +389,7 @@ def get_scores( Returns: A tensor of shape (n_transitions,) containing the scores for each transition. """ - if transitions.is_backward: + if self.debug and transitions.is_backward: raise ValueError("Backward transitions are not supported") if len(transitions) == 0: @@ -372,7 +401,8 @@ def get_scores( actions = transitions.actions[mask] all_log_rewards = transitions.all_log_rewards[mask] - check_compatibility(states, actions, transitions) + if self.debug: + check_compatibility(states, actions, transitions) if transitions.conditions is not None: with has_conditions_exception_handler("pf", self.pf): @@ -392,9 +422,11 @@ def get_scores( else: # Evaluate the log PF of the actions sampled off policy. valid_log_pf_actions = pf_dist.log_prob(actions.tensor) - valid_log_pf_s_exit = pf_dist.log_prob( - torch.full_like(actions.tensor, actions.__class__.exit_action[0].item()) - ) + # Avoid .item() in hot path to stay compile-friendly; broadcast exit_action tensor. + exit_action_tensor = actions.__class__.exit_action.to( + actions.tensor.device, dtype=actions.tensor.dtype + ).expand_as(actions.tensor) + valid_log_pf_s_exit = pf_dist.log_prob(exit_action_tensor) # The following two lines are slightly inefficient, given that most # next_states are also states, for which we already did a forward pass. @@ -407,9 +439,7 @@ def get_scores( valid_log_pf_s_prime_exit = self.pf.to_probability_distribution( valid_next_states, module_output - ).log_prob( - torch.full_like(actions.tensor, actions.__class__.exit_action[0].item()) - ) + ).log_prob(exit_action_tensor[: len(valid_next_states)]) non_exit_actions = actions[~actions.is_exit] @@ -435,7 +465,7 @@ def get_scores( targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit scores = preds - targets - if torch.any(torch.isinf(scores)): + if self.debug and torch.any(torch.isinf(scores)): raise ValueError("scores contains inf") return scores @@ -462,7 +492,8 @@ def loss( del env warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs) scores = self.get_scores( - transitions, recalculate_all_logprobs=recalculate_all_logprobs + transitions, + recalculate_all_logprobs=recalculate_all_logprobs, ) scores = scores**2 return loss_reduce(scores, reduction) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index c1a70d59..eb5f5a1d 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -41,15 +41,18 @@ class FMGFlowNet(GFlowNet[StatesContainer[DiscreteStates]]): the default (non-recurrent) PolicyMixin interface. """ - def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): + def __init__( + self, logF: DiscretePolicyEstimator, alpha: float = 1.0, debug: bool = False + ): """Initializes a FMGFlowNet instance. Args: logF: A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating the log flow of the edges (states -> next_states). alpha: A scalar weight for the reward matching loss. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ - super().__init__() + super().__init__(debug=debug) assert isinstance( logF, PolicyMixin ), "logF must use the default PolicyMixin interface" @@ -254,9 +257,10 @@ def loss( The computed flow matching loss as a tensor. The shape depends on the reduction method. """ - assert isinstance(states_container.intermediary_states, DiscreteStates) - assert isinstance(states_container.terminating_states, DiscreteStates) - if recalculate_all_logprobs: + if self.debug: + assert isinstance(states_container.intermediary_states, DiscreteStates) + assert isinstance(states_container.terminating_states, DiscreteStates) + if recalculate_all_logprobs and self.debug: warnings.warn( "recalculate_all_logprobs is not used for FM. Ignoring the argument." ) diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 37296c61..9d5ddbdf 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -83,6 +83,7 @@ def __init__( log_reward_clip_min: float = -float("inf"), forward_looking: bool = False, constant_pb: bool = False, + debug: bool = False, ): """Initializes a SubTBGFlowNet instance. @@ -100,10 +101,15 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ super().__init__( - pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min + pf, + pb, + constant_pb=constant_pb, + log_reward_clip_min=log_reward_clip_min, + debug=debug, ) assert any( isinstance(logF, cls) @@ -535,12 +541,16 @@ def loss( trajectories: The batch of trajectories to compute the loss with. recalculate_all_logprobs: Whether to re-evaluate all logprobs. reduction: The reduction method to use ('mean', 'sum', or 'none'). + Note: for geometric-based sub-trajectory weighting, 'mean' is not + supported and is coerced to 'sum' (a warning is emitted when + debug=True). Returns: The computed sub-trajectory balance loss as a tensor. The shape depends on the reduction method. """ - warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) # Get all scores and masks from the trajectories. scores, flattening_masks = self.get_scores( env, trajectories, recalculate_all_logprobs=recalculate_all_logprobs @@ -576,7 +586,8 @@ def loss( weights = ratio * ( L ** torch.arange(max_len, device=per_length_losses.device) ) - assert (weights.sum() - 1.0).abs() < 1e-5, f"{weights.sum()}" + if self.debug: + assert (weights.sum() - 1.0).abs() < 1e-5, f"{weights.sum()}" return (per_length_losses * weights).sum() # TODO: we need to know what reductions are valid for each weighting method. @@ -593,17 +604,19 @@ def loss( raise ValueError(f"Unknown weighting method {self.weighting}") flat_contributions = contributions[~flattening_mask] - assert ( - flat_contributions.sum() - 1.0 - ).abs() < 1e-5, f"{flat_contributions.sum()}" + if self.debug: + assert ( + flat_contributions.sum() - 1.0 + ).abs() < 1e-5, f"{flat_contributions.sum()}" final_scores = flat_contributions * all_scores[~flattening_mask].pow(2) # TODO: default was sum, should we allow mean? if reduction == "mean": - warnings.warn( - "Mean reduction is not supported for SubTBGFlowNet with geometric " - "weighting, using sum instead." - ) + if self.debug: + warnings.warn( + "Mean reduction is not supported for SubTBGFlowNet with geometric " + "weighting, using sum instead." + ) reduction = "sum" return loss_reduce(final_scores, reduction) diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 8e2d6e6b..4cea01fa 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -49,6 +49,7 @@ def __init__( init_logZ: float = 0.0, constant_pb: bool = False, log_reward_clip_min: float = -float("inf"), + debug: bool = False, ): """Initializes a TBGFlowNet instance. @@ -63,9 +64,14 @@ def __init__( is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. log_reward_clip_min: If finite, clips log rewards to this value. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ super().__init__( - pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min + pf, + pb, + constant_pb=constant_pb, + log_reward_clip_min=log_reward_clip_min, + debug=debug, ) self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) @@ -109,7 +115,8 @@ def loss( reduction method. """ del env # unused - warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) scores = self.get_scores( trajectories, recalculate_all_logprobs=recalculate_all_logprobs ) @@ -126,7 +133,7 @@ def loss( logZ = cast(torch.Tensor, logZ) scores = (scores + logZ.squeeze()).pow(2) loss = loss_reduce(scores, reduction) - if torch.isnan(loss).any(): + if self.debug and torch.isnan(loss).any(): raise ValueError("loss is nan") return loss @@ -170,13 +177,14 @@ def loss( the reduction method. """ del env # unused - warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) scores = self.get_scores( trajectories, recalculate_all_logprobs=recalculate_all_logprobs ) scores = (scores - scores.mean()).pow(2) loss = loss_reduce(scores, reduction) - if torch.isnan(loss).any(): + if self.debug and torch.isnan(loss).any(): raise ValueError("loss is NaN.") return loss From b13b5b7eb1af4df2e36530558f934d32e989974c Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 12 Dec 2025 09:11:03 -0500 Subject: [PATCH 02/16] added loss information for special cases --- docs/source/guides/losses.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/guides/losses.md b/docs/source/guides/losses.md index 45c79f15..b6535b18 100644 --- a/docs/source/guides/losses.md +++ b/docs/source/guides/losses.md @@ -5,7 +5,7 @@ GFlowNets can be trained with different losses, each of which requires a differe Currently, the implemented losses are: - Flow Matching: This is the original loss, and is mostly made available for completeness. It is slow to compute this loss, and also hard to optimize, so is generally not recommended for any significantly hard to learn problem. -- Detailed Balance (and it's modified variant). +- Detailed Balance (and its modified variant). In forward-looking mode, rewards must be defined on edges; the current implementation treats the edge reward as the difference between the successor and current state rewards, so only enable this when that matches your environment. - Trajectory Balance -- Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined [here](https://www.semanticscholar.org/reader/f2c32fe3f7f3e2e9d36d833e32ec55fc93f900f5). Other strategies exist and are implemented [here](https://github.com/gfnorg/torchgfn/tree/master/src/gfn/losses/sub_trajectory_balance.py). +- Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined [here](https://www.semanticscholar.org/reader/f2c32fe3f7f3e2e9d36d833e32ec55fc93f900f5). Other strategies exist and are implemented [here](https://github.com/gfnorg/torchgfn/tree/master/src/gfn/losses/sub_trajectory_balance.py). When using geometric-based weighting, the 'mean' reduction is not supported; requests for a mean reduction are coerced to a sum (a warning is emitted when debug is enabled). - Log Partition Variance loss. Introduced [here](https://arxiv.org/abs/2302.05446) From d83fa5dc3383aab40c5c2b73098d97592e21226b Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 12 Dec 2025 09:12:38 -0500 Subject: [PATCH 03/16] wandb removed by default, fixed imports for script / module runs --- tutorials/examples/train_hypergrid.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 6f3db577..193e4cce 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -58,10 +58,18 @@ from gfn.utils.common import Timer, set_seed from gfn.utils.distributed import DistributedContext, initialize_distributed_compute from gfn.utils.modules import MLP, DiscreteUniform, Tabular -from tutorials.examples.multinode.spawn_policy import ( - AsyncSelectiveAveragingPolicy, - AverageAllPolicy, -) + +# Allow running both as a package module and as a standalone script +try: + from tutorials.examples.multinode.spawn_policy import ( + AsyncSelectiveAveragingPolicy, + AverageAllPolicy, + ) +except ImportError: + from multinode.spawn_policy import ( + AsyncSelectiveAveragingPolicy, + AverageAllPolicy, + ) class ModesReplayBufferManager(ReplayBufferManager): @@ -1326,13 +1334,13 @@ def cleanup(): parser.add_argument( "--wandb_project", type=str, - default="torchgfn", + default="", help="Name of the wandb project. If empty, don't use wandb", ) parser.add_argument( "--wandb_entity", type=str, - default="torchgfn", + default="", help="Name of the wandb entity. If empty, don't use wandb", ) parser.add_argument( From 5e50be75d383a1f58961c6aabc58218ec7ab8756 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 12 Dec 2025 15:23:43 -0500 Subject: [PATCH 04/16] switched seeding to use a local generator to prevent side-effects --- src/gfn/gym/bitSequence.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/gfn/gym/bitSequence.py b/src/gfn/gym/bitSequence.py index ed0ecb95..fadcc9a5 100644 --- a/src/gfn/gym/bitSequence.py +++ b/src/gfn/gym/bitSequence.py @@ -449,10 +449,6 @@ def make_modes_set(self, seed) -> torch.Tensor: ValueError: If the number of requested modes exceeds the number of possible unique sequences. """ - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - if self.H is None: self.H = torch.tensor( [ @@ -472,11 +468,16 @@ def make_modes_set(self, seed) -> torch.Tensor: "Not enough unique sequences available for the set of modes." ) + g = torch.Generator(device=self.device) + g.manual_seed(seed) + unique_indices = set() unique_sequences = [] while len(unique_sequences) < self.n_modes: - candidate = tuple(torch.randint(0, self.H.shape[0], (m,)).tolist()) + candidate = tuple( + torch.randint(0, self.H.shape[0], (m,), generator=g).tolist() + ) if candidate not in unique_indices: unique_indices.add(candidate) unique_sequences.append(candidate) From 4621d71b0c75e18a933f3f1c0eb040fae88d1ae9 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 12 Dec 2025 16:53:36 -0500 Subject: [PATCH 05/16] tests for vectorized versions matching legacy --- testing/test_gflownet.py | 105 +++++++++++++++++++++++++++++ testing/test_gflownets.py | 135 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 240 insertions(+) create mode 100644 testing/test_gflownets.py diff --git a/testing/test_gflownet.py b/testing/test_gflownet.py index 953d217f..48cbd2e2 100644 --- a/testing/test_gflownet.py +++ b/testing/test_gflownet.py @@ -1,11 +1,19 @@ +import pytest +import torch + from gfn.containers import StatesContainer, Trajectories from gfn.containers.base import Container from gfn.estimators import DiscretePolicyEstimator from gfn.gflownet import FMGFlowNet, TBGFlowNet +from gfn.gflownet.base import loss_reduce from gfn.gym import Box, HyperGrid from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP from gfn.preprocessors import KHotPreprocessor from gfn.states import DiscreteStates +from gfn.utils.handlers import ( + has_conditions_exception_handler, + no_conditions_exception_handler, +) from gfn.utils.modules import MLP @@ -86,3 +94,100 @@ def test_pytorch_inheritance(): assert hasattr( fmgflownet.state_dict(), "__dict__" ), "Expected gflownet to have indexable state_dict() method inherited from nn.Module" + + +@pytest.mark.parametrize("seed", [0, 12, 47, 67]) +def test_flow_matching_vectorized_matches_original(seed): + torch.manual_seed(seed) + env = HyperGrid(ndim=2) + preprocessor = KHotPreprocessor(ndim=env.ndim, height=env.height) + module = MLP(input_dim=preprocessor.output_dim, output_dim=env.n_actions) + estimator = DiscretePolicyEstimator( + module, n_actions=env.n_actions, preprocessor=preprocessor + ) + gflownet = FMGFlowNet(estimator) + + trajectories = gflownet.sample_trajectories(env, n=6) + states_container = gflownet.to_training_samples(trajectories) + states = states_container.intermediary_states + conditions = states_container.intermediary_conditions + + if len(states) == 0: + # If the sample produced only terminal states, resample with more trajectories. + trajectories = gflownet.sample_trajectories(env, n=12) + states_container = gflownet.to_training_samples(trajectories) + states = states_container.intermediary_states + conditions = states_container.intermediary_conditions + + assert len(states) > 0 + + def flow_matching_loss_original( + self, env, states, conditions, reduction: str = "mean" + ): + if len(states) == 0: + return torch.tensor(0.0, device=states.device) + assert len(states.batch_shape) == 1 + assert not torch.any(states.is_initial_state) + incoming_log_flows = torch.full_like( + states.backward_masks, -float("inf"), dtype=torch.get_default_dtype() + ) + outgoing_log_flows = torch.full_like( + states.forward_masks, -float("inf"), dtype=torch.get_default_dtype() + ) + for action_idx in range(env.n_actions - 1): + valid_backward_mask = states.backward_masks[:, action_idx] + valid_forward_mask = states.forward_masks[:, action_idx] + valid_backward_states = states[valid_backward_mask] + valid_forward_states = states[valid_forward_mask] + backward_actions = torch.full_like( + valid_backward_states.backward_masks[:, 0], action_idx, dtype=torch.long + ).unsqueeze(-1) + backward_actions = env.actions_from_tensor(backward_actions) + valid_backward_states_parents = env._backward_step( + valid_backward_states, backward_actions + ) + if conditions is not None: + valid_backward_conditions = conditions[valid_backward_mask] + valid_forward_conditions = conditions[valid_forward_mask] + with has_conditions_exception_handler("logF", self.logF): + incoming_log_flows[valid_backward_mask, action_idx] = self.logF( + valid_backward_states_parents, + valid_backward_conditions, + )[:, action_idx] + outgoing_log_flows[valid_forward_mask, action_idx] = self.logF( + valid_forward_states, + valid_forward_conditions, + )[:, action_idx] + else: + with no_conditions_exception_handler("logF", self.logF): + incoming_log_flows[valid_backward_mask, action_idx] = self.logF( + valid_backward_states_parents, + )[:, action_idx] + outgoing_log_flows[valid_forward_mask, action_idx] = self.logF( + valid_forward_states, + )[:, action_idx] + valid_forward_mask = states.forward_masks[:, -1] + if conditions is not None: + with has_conditions_exception_handler("logF", self.logF): + outgoing_log_flows[valid_forward_mask, -1] = self.logF( + states[valid_forward_mask], + conditions[valid_forward_mask], + )[:, -1] + else: + with no_conditions_exception_handler("logF", self.logF): + outgoing_log_flows[valid_forward_mask, -1] = self.logF( + states[valid_forward_mask], + )[:, -1] + log_incoming_flows = torch.logsumexp(incoming_log_flows, dim=-1) + log_outgoing_flows = torch.logsumexp(outgoing_log_flows, dim=-1) + scores = (log_incoming_flows - log_outgoing_flows).pow(2) + return loss_reduce(scores, reduction) + + loss_original = flow_matching_loss_original( + gflownet, env, states, conditions, reduction="mean" + ) + loss_vectorized = gflownet.flow_matching_loss( + env, states, conditions, reduction="mean" + ) + + torch.testing.assert_close(loss_vectorized, loss_original) diff --git a/testing/test_gflownets.py b/testing/test_gflownets.py new file mode 100644 index 00000000..5eff8b03 --- /dev/null +++ b/testing/test_gflownets.py @@ -0,0 +1,135 @@ +from types import MethodType + +import pytest +import torch + +from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet + + +class _DummyTrajectories: + """Minimal trajectories carrier for get_scores vectorization test.""" + + def __init__(self, terminating_idx: torch.Tensor, max_length: int): + self.terminating_idx = terminating_idx + self.max_length = max_length + self.n_trajectories = terminating_idx.shape[0] + + def __len__(self) -> int: + return self.n_trajectories + + +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_subtb_get_scores_vectorized_matches_original(seed: int): + torch.manual_seed(seed) + max_len = 3 + n_traj = 4 + + # Synthetic inputs for the get_scores pipeline. + terminating_idx = torch.tensor([1, 2, 3, 2]) + log_pf_trajectories = torch.randn(max_len, n_traj) + log_pb_trajectories = torch.randn(max_len, n_traj) + log_state_flows = torch.randn(max_len, n_traj) + sink_states_mask = torch.zeros(max_len, n_traj, dtype=torch.bool) + is_terminal_mask = torch.zeros(max_len, n_traj, dtype=torch.bool) + + preds_list = [torch.randn(max_len + 1 - i, n_traj) for i in range(1, max_len + 1)] + targets_list = [torch.randn(max_len + 1 - i, n_traj) for i in range(1, max_len + 1)] + + trajectories = _DummyTrajectories( + terminating_idx=terminating_idx, max_length=max_len + ) + env = object() # Unused by the monkeypatched methods. + + # Build a SubTBGFlowNet instance without running its heavy __init__. + model = SubTBGFlowNet.__new__(SubTBGFlowNet) + torch.nn.Module.__init__(model) + model.debug = False + model.log_reward_clip_min = float("-inf") + + # Monkeypatch the dependencies used inside get_scores to deterministic tensors. + model.get_pfs_and_pbs = MethodType( + lambda self, traj, recalculate_all_logprobs=True: ( + log_pf_trajectories, + log_pb_trajectories, + ), + model, + ) + model.calculate_log_state_flows = MethodType( + lambda self, _env, _traj, _log_pf: log_state_flows, model + ) + model.calculate_masks = MethodType( + lambda self, _log_state_flows, _traj: (sink_states_mask, is_terminal_mask), + model, + ) + model.calculate_preds = MethodType( + lambda self, _log_pf_cum, _log_state_flows, i: preds_list[i - 1], model + ) + model.calculate_targets = MethodType( + lambda self, _traj, _preds, _log_pb_cum, _log_state_flows, _term_mask, _sink_mask, i: targets_list[ + i - 1 + ], + model, + ) + + def original_get_scores(self, env, trajectories, recalculate_all_logprobs=True): + log_pf_trajectories_, log_pb_trajectories_ = self.get_pfs_and_pbs( + trajectories, recalculate_all_logprobs=recalculate_all_logprobs + ) + + log_pf_trajectories_cum = self.cumulative_logprobs( + trajectories, log_pf_trajectories_ + ) + log_pb_trajectories_cum = self.cumulative_logprobs( + trajectories, log_pb_trajectories_ + ) + + log_state_flows_ = self.calculate_log_state_flows( + env, trajectories, log_pf_trajectories_ + ) + sink_states_mask_, is_terminal_mask_ = self.calculate_masks( + log_state_flows_, trajectories + ) + + flattening_masks_orig = [] + scores_orig = [] + for i in range(1, 1 + trajectories.max_length): + preds = self.calculate_preds(log_pf_trajectories_cum, log_state_flows_, i) + targets = self.calculate_targets( + trajectories, + preds, + log_pb_trajectories_cum, + log_state_flows_, + is_terminal_mask_, + sink_states_mask_, + i, + ) + + flattening_mask = trajectories.terminating_idx.lt( + torch.arange( + i, + trajectories.max_length + 1, + device=trajectories.terminating_idx.device, + ).unsqueeze(-1) + ) + + flat_preds = preds[~flattening_mask] + if self.debug and torch.any(torch.isnan(flat_preds)): + raise ValueError("NaN in preds") + + flat_targets = targets[~flattening_mask] + if self.debug and torch.any(torch.isnan(flat_targets)): + raise ValueError("NaN in targets") + + flattening_masks_orig.append(flattening_mask) + scores_orig.append(preds - targets) + + return scores_orig, flattening_masks_orig + + orig_scores, orig_masks = original_get_scores(model, env, trajectories) + vec_scores, vec_masks = model.get_scores(env, trajectories) # type: ignore + + assert len(orig_scores) == len(vec_scores) == trajectories.max_length + for orig, vec in zip(orig_scores, vec_scores): + torch.testing.assert_close(vec, orig) + for orig_m, vec_m in zip(orig_masks, vec_masks): + assert torch.equal(vec_m, orig_m) From 4b8d30a0f624b8ae777e7b6c79b7c533fd3e0732 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 12 Dec 2025 22:08:46 -0500 Subject: [PATCH 06/16] checkpoint before we re-try to vectorize the subtb loss calculation --- src/gfn/gflownet/detailed_balance.py | 20 ++- src/gfn/gflownet/flow_matching.py | 97 ++++++++----- testing/test_gflownets.py | 166 ++++++++++++++-------- tutorials/examples/test_scripts.py | 7 +- tutorials/examples/train_bit_sequences.py | 17 ++- 5 files changed, 191 insertions(+), 116 deletions(-) diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 79fbfc22..f29fa919 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -1,4 +1,3 @@ -import math from typing import Tuple import torch @@ -205,10 +204,10 @@ def get_scores( not transitions.states.is_sink_state.any() ), "Transition from sink state is not allowed. This is a bug." - ### Compute log_pf and log_pb + # Compute log_pf and log_pb log_pf, log_pb = self.get_pfs_and_pbs(transitions, recalculate_all_logprobs) - ### Compute log_F_s + # Compute log_F_s # LogF is potentially a conditional computation. if transitions.conditions is not None: with has_conditions_exception_handler("logF", self.logF): @@ -217,7 +216,7 @@ def get_scores( with no_conditions_exception_handler("logF", self.logF): log_F_s = self.logF(states).squeeze(-1) - ### Compute log_F_s_next + # Compute log_F_s_next log_F_s_next = torch.zeros_like(log_F_s) is_terminating = transitions.is_terminating is_intermediate = ~is_terminating @@ -258,9 +257,8 @@ def get_scores( else: log_rewards_state = env.log_reward(states) log_rewards_next = env.log_reward(interm_next_states) - if math.isfinite(self.log_reward_clip_min): - log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min) - log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min) + log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min) + log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min) log_F_s = log_F_s + log_rewards_state log_F_s_next[is_intermediate] = ( @@ -270,11 +268,10 @@ def get_scores( # Assign log_F_s_next for terminating transitions as log_rewards log_rewards = transitions.log_rewards assert log_rewards is not None - if math.isfinite(self.log_reward_clip_min): - log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) + log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) log_F_s_next[is_terminating] = log_rewards[is_terminating] - ### Compute scores + # Compute scores preds = log_pf + log_F_s targets = log_pb + log_F_s_next scores = preds - targets @@ -490,7 +487,8 @@ def loss( on the reduction method. """ del env - warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs) scores = self.get_scores( transitions, recalculate_all_logprobs=recalculate_all_logprobs, diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index eb5f5a1d..b07c2559 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -126,54 +126,75 @@ def flow_matching_loss( assert len(states.batch_shape) == 1 assert not torch.any(states.is_initial_state) - incoming_log_flows = torch.full_like( - states.backward_masks, -float("inf"), dtype=torch.get_default_dtype() + incoming_log_flows = torch.full( + states.backward_masks.shape, + -float("inf"), + device=states.device, + dtype=torch.get_default_dtype(), ) - outgoing_log_flows = torch.full_like( - states.forward_masks, -float("inf"), dtype=torch.get_default_dtype() + outgoing_log_flows = torch.full( + states.forward_masks.shape, + -float("inf"), + device=states.device, + dtype=torch.get_default_dtype(), ) - # TODO: Need to vectorize this loop. - for action_idx in range(env.n_actions - 1): - valid_backward_mask = states.backward_masks[:, action_idx] - valid_forward_mask = states.forward_masks[:, action_idx] - valid_backward_states = states[valid_backward_mask] - valid_forward_states = states[valid_forward_mask] - - backward_actions = torch.full_like( - valid_backward_states.backward_masks[:, 0], action_idx, dtype=torch.long - ).unsqueeze(-1) - backward_actions = env.actions_from_tensor(backward_actions) - - valid_backward_states_parents = env._backward_step( - valid_backward_states, backward_actions - ) + # Vectorized over actions. + valid_backward = states.backward_masks + backward_indices = valid_backward.nonzero(as_tuple=False) + if backward_indices.numel() > 0: + backward_state_idx = backward_indices[:, 0] # time index + backward_action_idx = backward_indices[:, 1] # action index + backward_states = states[backward_state_idx] + backward_actions_tensor = backward_action_idx.view(-1, 1) + backward_actions = env.actions_from_tensor(backward_actions_tensor) + backward_parents = env._backward_step(backward_states, backward_actions) # type: ignore + + # calculate log flows of backward actions. + if conditions is not None: + backward_conditions = conditions[backward_state_idx] + with has_conditions_exception_handler("logF", self.logF): + backward_logF = ( + self.logF(backward_parents, backward_conditions) + .gather(1, backward_action_idx.view(-1, 1)) + .squeeze(1) + ) + else: + with no_conditions_exception_handler("logF", self.logF): + backward_logF = ( + self.logF(backward_parents) + .gather(1, backward_action_idx.view(-1, 1)) + .squeeze(1) + ) + + incoming_log_flows[backward_state_idx, backward_action_idx] = backward_logF + + # Vectorized over all non-exit forward actions. + valid_forward = states.forward_masks[:, :-1] + forward_indices = valid_forward.nonzero(as_tuple=False) + if forward_indices.numel() > 0: + forward_state_idx = forward_indices[:, 0] + forward_action_idx = forward_indices[:, 1] + forward_states = states[forward_state_idx] if conditions is not None: # Mask out only valid conditions elements. - valid_backward_conditions = conditions[valid_backward_mask] - valid_forward_conditions = conditions[valid_forward_mask] - + forward_conditions = conditions[forward_state_idx] with has_conditions_exception_handler("logF", self.logF): - incoming_log_flows[valid_backward_mask, action_idx] = self.logF( - valid_backward_states_parents, - valid_backward_conditions, - )[:, action_idx] - - outgoing_log_flows[valid_forward_mask, action_idx] = self.logF( - valid_forward_states, - valid_forward_conditions, - )[:, action_idx] - + forward_logF = ( + self.logF(forward_states, forward_conditions) + .gather(1, forward_action_idx.view(-1, 1)) + .squeeze(1) + ) else: with no_conditions_exception_handler("logF", self.logF): - incoming_log_flows[valid_backward_mask, action_idx] = self.logF( - valid_backward_states_parents, - )[:, action_idx] + forward_logF = ( + self.logF(forward_states) + .gather(1, forward_action_idx.view(-1, 1)) + .squeeze(1) + ) - outgoing_log_flows[valid_forward_mask, action_idx] = self.logF( - valid_forward_states, - )[:, action_idx] + outgoing_log_flows[forward_state_idx, forward_action_idx] = forward_logF # Now the exit action. valid_forward_mask = states.forward_masks[:, -1] diff --git a/testing/test_gflownets.py b/testing/test_gflownets.py index 5eff8b03..e27a7c74 100644 --- a/testing/test_gflownets.py +++ b/testing/test_gflownets.py @@ -1,75 +1,64 @@ -from types import MethodType - import pytest import torch +from gfn.containers.trajectories import Trajectories +from gfn.estimators import DiscretePolicyEstimator, ScalarEstimator from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet - - -class _DummyTrajectories: - """Minimal trajectories carrier for get_scores vectorization test.""" - - def __init__(self, terminating_idx: torch.Tensor, max_length: int): - self.terminating_idx = terminating_idx - self.max_length = max_length - self.n_trajectories = terminating_idx.shape[0] - - def __len__(self) -> int: - return self.n_trajectories +from gfn.gym.hypergrid import HyperGrid +from gfn.preprocessors import KHotPreprocessor +from gfn.samplers import Sampler +from gfn.utils.modules import MLP @pytest.mark.parametrize("seed", [0, 1, 2]) def test_subtb_get_scores_vectorized_matches_original(seed: int): torch.manual_seed(seed) - max_len = 3 n_traj = 4 - # Synthetic inputs for the get_scores pipeline. - terminating_idx = torch.tensor([1, 2, 3, 2]) - log_pf_trajectories = torch.randn(max_len, n_traj) - log_pb_trajectories = torch.randn(max_len, n_traj) - log_state_flows = torch.randn(max_len, n_traj) - sink_states_mask = torch.zeros(max_len, n_traj, dtype=torch.bool) - is_terminal_mask = torch.zeros(max_len, n_traj, dtype=torch.bool) - - preds_list = [torch.randn(max_len + 1 - i, n_traj) for i in range(1, max_len + 1)] - targets_list = [torch.randn(max_len + 1 - i, n_traj) for i in range(1, max_len + 1)] - - trajectories = _DummyTrajectories( - terminating_idx=terminating_idx, max_length=max_len + # Deterministic HyperGrid env and frozen estimators so real methods can run. + env = HyperGrid(ndim=2, height=3, device="cpu", debug=False) + preproc = KHotPreprocessor(height=env.height, ndim=env.ndim) + + # Tiny MLPs with random weights (frozen for determinism). + module_pf = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions) + module_pb = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions - 1) + module_logF = MLP(input_dim=preproc.output_dim, output_dim=1) + for mod in (module_pf, module_pb, module_logF): + for p in mod.parameters(): + p.requires_grad_(False) + + pf = DiscretePolicyEstimator( + module=module_pf, + n_actions=env.n_actions, + preprocessor=preproc, + is_backward=False, ) - env = object() # Unused by the monkeypatched methods. + pb = DiscretePolicyEstimator( + module=module_pb, n_actions=env.n_actions, preprocessor=preproc, is_backward=True + ) + logF = ScalarEstimator(module=module_logF, preprocessor=preproc) - # Build a SubTBGFlowNet instance without running its heavy __init__. - model = SubTBGFlowNet.__new__(SubTBGFlowNet) - torch.nn.Module.__init__(model) + # Initialize model via __init__ to set up real methods. + model = SubTBGFlowNet( + pf=pf, pb=pb, logF=logF, weighting="geometric_within", lamda=0.9 + ) model.debug = False model.log_reward_clip_min = float("-inf") - - # Monkeypatch the dependencies used inside get_scores to deterministic tensors. - model.get_pfs_and_pbs = MethodType( - lambda self, traj, recalculate_all_logprobs=True: ( - log_pf_trajectories, - log_pb_trajectories, - ), - model, - ) - model.calculate_log_state_flows = MethodType( - lambda self, _env, _traj, _log_pf: log_state_flows, model - ) - model.calculate_masks = MethodType( - lambda self, _log_state_flows, _traj: (sink_states_mask, is_terminal_mask), - model, - ) - model.calculate_preds = MethodType( - lambda self, _log_pf_cum, _log_state_flows, i: preds_list[i - 1], model - ) - model.calculate_targets = MethodType( - lambda self, _traj, _preds, _log_pb_cum, _log_state_flows, _term_mask, _sink_mask, i: targets_list[ - i - 1 - ], - model, + model.eval() + pf.eval() + pb.eval() + logF.eval() + + # Sample a deterministic batch of trajectories with frozen estimators. + sampler = Sampler(estimator=pf) + trajectories: Trajectories = sampler.sample_trajectories( + env, + n=n_traj, + epsilon=0.0, + save_logprobs=True, + save_estimator_outputs=False, ) + max_len = trajectories.max_length # noqa: F841 used implicitly by shapes def original_get_scores(self, env, trajectories, recalculate_all_logprobs=True): log_pf_trajectories_, log_pb_trajectories_ = self.get_pfs_and_pbs( @@ -125,11 +114,62 @@ def original_get_scores(self, env, trajectories, recalculate_all_logprobs=True): return scores_orig, flattening_masks_orig - orig_scores, orig_masks = original_get_scores(model, env, trajectories) - vec_scores, vec_masks = model.get_scores(env, trajectories) # type: ignore + def normalize_scores_masks( + scores, masks, trajectories: Trajectories + ) -> tuple[torch.Tensor, torch.Tensor]: + """Convert list outputs to padded tensors; pass tensors through unchanged.""" + if isinstance(scores, torch.Tensor): + assert isinstance(masks, torch.Tensor) + return scores, masks + + assert isinstance(scores, (list, tuple)) + assert isinstance(masks, (list, tuple)) + + max_len = trajectories.max_length + n_traj = ( + trajectories.n_trajectories + if hasattr(trajectories, "n_trajectories") + else len(trajectories) + ) + device = trajectories.terminating_idx.device + dtype = scores[0].dtype + + scores_padded = torch.zeros( + (max_len, max_len, n_traj), dtype=dtype, device=device + ) + masks_padded = torch.ones( + (max_len, max_len, n_traj), dtype=torch.bool, device=device + ) + + for i, (s, m) in enumerate(zip(scores, masks), start=1): + seq_len = s.shape[0] + scores_padded[i - 1, :seq_len] = s + masks_padded[i - 1, :seq_len] = m + + return scores_padded, masks_padded + + # Recompute logprobs to ensure PF/PB are evaluated for both paths. + orig_scores_list, orig_masks_list = original_get_scores( + model, env, trajectories, recalculate_all_logprobs=True + ) + vec_scores, vec_masks = model.get_scores( + env, trajectories, recalculate_all_logprobs=True + ) # type: ignore + + vec_scores_t, vec_masks_t = normalize_scores_masks( + vec_scores, vec_masks, trajectories + ) + orig_scores_t, orig_masks_t = normalize_scores_masks( + orig_scores_list, orig_masks_list, trajectories + ) + + valid_mask = ~orig_masks_t + if not torch.allclose( + vec_scores_t[valid_mask], orig_scores_t[valid_mask], equal_nan=True + ): + max_diff = (vec_scores_t[valid_mask] - orig_scores_t[valid_mask]).abs().max() + raise AssertionError( + f"Score mismatch on valid positions; max_abs_diff={max_diff.item()}" + ) - assert len(orig_scores) == len(vec_scores) == trajectories.max_length - for orig, vec in zip(orig_scores, vec_scores): - torch.testing.assert_close(vec, orig) - for orig_m, vec_m in zip(orig_masks, vec_masks): - assert torch.equal(vec_m, orig_m) + torch.testing.assert_close(vec_masks_t, orig_masks_t, equal_nan=True) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index cacd2637..e4bcbf55 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -211,7 +211,7 @@ class BoxArgs(CommonArgs): @dataclass class BitSequenceArgs(CommonArgs): - n_iterations: int = 5000 + n_iterations: int = 1000 word_size: int = 1 seq_size: int = 4 n_modes: int = 2 @@ -220,6 +220,7 @@ class BitSequenceArgs(CommonArgs): lr_Z: 1e-2 seed: int = 0 batch_size: int = 32 + deterministic_mode: bool = True @dataclass @@ -657,7 +658,9 @@ def test_bitsequence(seq_size: int, n_modes: int): args = BitSequenceArgs(seq_size=seq_size, n_modes=n_modes) final_l1_dist = train_bitsequence_main(args) assert final_l1_dist is not None - # print(f"[DEBUG] BitSequence seq_size={seq_size}, n_modes={n_modes}, l1={final_l1_dist}") + print( + f"[DEBUG] BitSequence seq_size={seq_size}, n_modes={n_modes}, l1={final_l1_dist}" + ) if seq_size == 4 and n_modes == 2: assert final_l1_dist <= 9e-5 if seq_size == 4 and n_modes == 4: diff --git a/tutorials/examples/train_bit_sequences.py b/tutorials/examples/train_bit_sequences.py index 12ad27eb..287ed82b 100644 --- a/tutorials/examples/train_bit_sequences.py +++ b/tutorials/examples/train_bit_sequences.py @@ -26,14 +26,21 @@ def estimated_dist(gflownet: PFBasedGFlowNet, env: BitSequence): def main(args): seed = args.seed if args.seed != 0 else DEFAULT_SEED - set_seed(seed) + set_seed(seed, deterministic_mode=args.deterministic_mode) device = torch.device( "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" ) H = torch.randint( 0, 2, (args.n_modes, args.seq_size), dtype=torch.long, device=device ) - env = BitSequence(args.word_size, args.seq_size, args.n_modes, H=H, debug=__debug__) + env = BitSequence( + args.word_size, + args.seq_size, + args.n_modes, + H=H, + seed=seed, + debug=__debug__, + ) if args.loss == "TB": pf = MLP(env.words_per_seq, env.n_actions) @@ -111,6 +118,12 @@ def main(args): parser.add_argument("--loss", type=str, default="TB", help="Loss to use") parser.add_argument("--no_cuda", type=bool, default=True, help="Device to use") parser.add_argument("--seed", type=int, default=0, help="Seed") + parser.add_argument( + "--deterministic_mode", + action="store_true", + default=False, + help="Use deterministic algorithms/threads where possible", + ) parser.add_argument( "--n_iterations", type=int, default=1000, help="Number of iterations" ) From c148a67a16d9be194142cac4d7f1e73babe192b0 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 12 Dec 2025 22:10:27 -0500 Subject: [PATCH 07/16] added benchmark --- tutorials/misc/bench_subtb_get_scores.py | 200 +++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 tutorials/misc/bench_subtb_get_scores.py diff --git a/tutorials/misc/bench_subtb_get_scores.py b/tutorials/misc/bench_subtb_get_scores.py new file mode 100644 index 00000000..d723c87b --- /dev/null +++ b/tutorials/misc/bench_subtb_get_scores.py @@ -0,0 +1,200 @@ +""" +Micro-benchmark for SubTBGFlowNet.get_scores vectorized vs original loop. + +This isolates get_scores by monkeypatching dependencies (calculate_preds/targets, +get_pfs_and_pbs, masks) to synthetic tensors so we can time the core logic. +Run on CPU; adjust sizes below to probe different max_len / batch regimes. +""" + +from __future__ import annotations + +import argparse +from types import MethodType +from typing import Any, Callable, Tuple + +import torch +from torch.utils import benchmark + +from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet + + +class _DummyTrajectories: + """Minimal trajectories carrier for benchmarking get_scores.""" + + def __init__(self, terminating_idx: torch.Tensor, max_length: int): + self.terminating_idx = terminating_idx + self.max_length = max_length + self.n_trajectories = terminating_idx.shape[0] + + def __len__(self) -> int: + return self.n_trajectories + + +def build_model_and_data( + max_len: int, n_traj: int, seed: int = 0 +) -> Tuple[SubTBGFlowNet, _DummyTrajectories, list[torch.Tensor], list[torch.Tensor]]: + torch.manual_seed(seed) + terminating_idx = torch.randint(1, max_len + 1, (n_traj,)) + log_pf_trajectories = torch.randn(max_len, n_traj) + log_pb_trajectories = torch.randn(max_len, n_traj) + log_state_flows = torch.randn(max_len, n_traj) + sink_states_mask = torch.zeros(max_len, n_traj, dtype=torch.bool) + is_terminal_mask = torch.zeros(max_len, n_traj, dtype=torch.bool) + + preds_list = [torch.randn(max_len + 1 - i, n_traj) for i in range(1, max_len + 1)] + targets_list = [torch.randn(max_len + 1 - i, n_traj) for i in range(1, max_len + 1)] + + trajectories = _DummyTrajectories( + terminating_idx=terminating_idx, max_length=max_len + ) + + # Build a SubTBGFlowNet instance without running heavy __init__. + model = SubTBGFlowNet.__new__(SubTBGFlowNet) + torch.nn.Module.__init__(model) + model.debug = False + model.log_reward_clip_min = float("-inf") + + # Monkeypatch the dependencies used inside get_scores to deterministic tensors. + model.get_pfs_and_pbs = MethodType( + lambda self, traj, recalculate_all_logprobs=True: ( + log_pf_trajectories, + log_pb_trajectories, + ), + model, + ) + model.calculate_log_state_flows = MethodType( + lambda self, _env, _traj, _log_pf: log_state_flows, model + ) + model.calculate_masks = MethodType( + lambda self, _log_state_flows, _traj: (sink_states_mask, is_terminal_mask), + model, + ) + model.calculate_preds = MethodType( + lambda self, _log_pf_cum, _log_state_flows, i: preds_list[i - 1], model + ) + model.calculate_targets = MethodType( + lambda self, _traj, _preds, _log_pb_cum, _log_state_flows, _term_mask, _sink_mask, i: targets_list[ + i - 1 + ], + model, + ) + + return model, trajectories, preds_list, targets_list + + +def original_get_scores( + model: SubTBGFlowNet, env, trajectories +) -> Tuple[list[torch.Tensor], list[torch.Tensor]]: + """Reference implementation (pre-vectorized) for comparison.""" + log_pf_trajectories_, log_pb_trajectories_ = model.get_pfs_and_pbs( + trajectories, recalculate_all_logprobs=True + ) + + log_pf_trajectories_cum = model.cumulative_logprobs( + trajectories, log_pf_trajectories_ + ) + log_pb_trajectories_cum = model.cumulative_logprobs( + trajectories, log_pb_trajectories_ + ) + + log_state_flows_ = model.calculate_log_state_flows( + env, trajectories, log_pf_trajectories_ + ) + sink_states_mask_, is_terminal_mask_ = model.calculate_masks( + log_state_flows_, trajectories + ) + + flattening_masks_orig = [] + scores_orig = [] + for i in range(1, 1 + trajectories.max_length): + preds = model.calculate_preds(log_pf_trajectories_cum, log_state_flows_, i) + targets = model.calculate_targets( + trajectories, + preds, + log_pb_trajectories_cum, + log_state_flows_, + is_terminal_mask_, + sink_states_mask_, + i, + ) + + flattening_mask = trajectories.terminating_idx.lt( + torch.arange( + i, + trajectories.max_length + 1, + device=trajectories.terminating_idx.device, + ).unsqueeze(-1) + ) + + flat_preds = preds[~flattening_mask] + if model.debug and torch.any(torch.isnan(flat_preds)): + raise ValueError("NaN in preds") + + flat_targets = targets[~flattening_mask] + if model.debug and torch.any(torch.isnan(flat_targets)): + raise ValueError("NaN in targets") + + flattening_masks_orig.append(flattening_mask) + scores_orig.append(preds - targets) + + return scores_orig, flattening_masks_orig + + +def run_once(mode: str, max_len: int, n_traj: int) -> float: + """Return median time (seconds) for the chosen mode.""" + model, trajectories, _, _ = build_model_and_data(max_len, n_traj) + env_obj: Any = object() + bench: Callable[[], Any] + + if mode == "original": + + def bench_original(): + return original_get_scores(model, env_obj, trajectories) # type: ignore[arg-type] + + bench = bench_original + elif mode == "vectorized": + + def bench_vectorized(): + return model.get_scores(env_obj, trajectories) # type: ignore[arg-type] + + bench = bench_vectorized + else: + raise ValueError(mode) + + t = benchmark.Timer( + stmt="bench()", + globals={"bench": bench}, + setup="", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=0.5) + return t.median + + +def main(): + parser = argparse.ArgumentParser() + # Defaults scaled ~10x to stress larger workloads; override with --sizes if needed. + parser.add_argument( + "--sizes", + nargs="+", + default=["80x640", "160x1280", "320x2560"], + ) + args = parser.parse_args() + + print("Benchmarking SubTBGFlowNet.get_scores (CPU)") + print(f"torch version: {torch.__version__}") + print(f"num threads: {torch.get_num_threads()}") + print() + print(f"{'size':>10} {'orig (ms)':>12} {'vec (ms)':>12} {'speedup':>8}") + + for size in args.sizes: + max_len_s, n_traj_s = size.lower().split("x") + max_len = int(max_len_s) + n_traj = int(n_traj_s) + t_orig = run_once("original", max_len, n_traj) * 1e3 + t_vec = run_once("vectorized", max_len, n_traj) * 1e3 + speedup = t_orig / t_vec if t_vec > 0 else float("inf") + print(f"{size:>10} {t_orig:12.3f} {t_vec:12.3f} {speedup:8.2f}x") + + +if __name__ == "__main__": + main() From 3bbacb0367b6796d84fad7f7f5b1fa56d23b8c8e Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 13 Dec 2025 01:23:11 -0500 Subject: [PATCH 08/16] major speedup --- src/gfn/gflownet/sub_trajectory_balance.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 9d5ddbdf..a18ee9ee 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -377,13 +377,14 @@ def get_scores( ).unsqueeze(-1) ) - flat_preds = preds[~flattening_mask] - if torch.any(torch.isnan(flat_preds)): - raise ValueError("NaN in preds") + if self.debug: + flat_preds = preds[~flattening_mask] + if torch.any(torch.isnan(flat_preds)): + raise ValueError("NaN in preds") - flat_targets = targets[~flattening_mask] - if torch.any(torch.isnan(flat_targets)): - raise ValueError("NaN in targets") + flat_targets = targets[~flattening_mask] + if torch.any(torch.isnan(flat_targets)): + raise ValueError("NaN in targets") flattening_masks.append(flattening_mask) scores.append(preds - targets) From d85d14957220c1aba210d7642caa205a2dad7d91 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 13 Dec 2025 02:55:24 -0500 Subject: [PATCH 09/16] optimized code, good performance on all devices without pure vectorization --- src/gfn/gflownet/sub_trajectory_balance.py | 29 +++--- testing/test_subtb_contributions.py | 109 +++++++++++++++++++++ tutorials/misc/bench_subtb_get_scores.py | 80 ++++++++++++--- 3 files changed, 190 insertions(+), 28 deletions(-) create mode 100644 testing/test_subtb_contributions.py diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index a18ee9ee..c82321cd 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -499,15 +499,17 @@ def get_geometric_within_contributions( Returns: The contributions tensor of shape (max_len * (max_len+1) / 2, n_trajectories). """ - L = self.lamda - max_len = trajectories.max_length t_idx = trajectories.terminating_idx + default_dtype = torch.get_default_dtype() + lam = torch.as_tensor(self.lamda, device=t_idx.device, dtype=default_dtype) + max_len = trajectories.max_length + + arange = torch.arange(max_len, device=t_idx.device, dtype=default_dtype) + pow_lam = lam.pow(arange) # The following tensor represents the weights given to each possible # sub-trajectory length. - contributions = (L ** torch.arange(max_len, device=t_idx.device).double()).to( - torch.get_default_dtype() - ) + contributions = pow_lam contributions = contributions.unsqueeze(-1).repeat(1, len(trajectories)) contributions = contributions.repeat_interleave( torch.arange(max_len, 0, -1, device=t_idx.device), @@ -515,15 +517,18 @@ def get_geometric_within_contributions( output_size=int(max_len * (max_len + 1) / 2), ) - # Now we need to divide each column by n + (n-1) lambda +...+ 1*lambda^{n-1} - # where n is the length of the trajectory corresponding to that column + # Now we need to divide each column by + # sum_{i=0}^{n-1} (n - i) * lambda^i + # where n is the length of the trajectory corresponding to that column. # We can do it the ugly way, or using the cool identity: # https://www.wolframalpha.com/input?i=sum%28%28n-i%29+*+lambda+%5Ei%2C+i%3D0..n%29 - per_trajectory_denom = ( - 1.0 - / (1 - L) ** 2 - * (L * (L ** t_idx.double() - 1) + (1 - L) * t_idx.double()) - ).to(torch.get_default_dtype()) + pow_cumsum = pow_lam.cumsum(0) + i_pow_cumsum = (arange * pow_lam).cumsum(0) + gather_idx = (t_idx.clamp(min=1) - 1).long() + sum_geom = pow_cumsum[gather_idx] + sum_i_geom = i_pow_cumsum[gather_idx] + t_idx_f = t_idx.to(default_dtype) + per_trajectory_denom = t_idx_f * sum_geom - sum_i_geom contributions = contributions / per_trajectory_denom / len(trajectories) return contributions diff --git a/testing/test_subtb_contributions.py b/testing/test_subtb_contributions.py new file mode 100644 index 00000000..de137c32 --- /dev/null +++ b/testing/test_subtb_contributions.py @@ -0,0 +1,109 @@ +from typing import cast + +import pytest +import torch + +from gfn.containers import Trajectories +from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet + + +class DummyTrajectories: + def __init__(self, terminating_idx: torch.Tensor): + self.terminating_idx = terminating_idx + self.max_length = int(torch.max(terminating_idx).item()) + + def __len__(self) -> int: + return self.terminating_idx.numel() + + +def _reference_contributions( + lam: float | torch.Tensor, + terminating_idx: torch.Tensor, + max_len: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + lam64 = torch.as_tensor(lam, dtype=torch.float64, device=device) + t_idx64 = terminating_idx.to(dtype=torch.float64) + + base = lam64.pow(torch.arange(max_len, device=device, dtype=torch.float64)) + base = base.unsqueeze(-1).repeat(1, len(terminating_idx)) + base = base.repeat_interleave( + torch.arange(max_len, 0, -1, device=device), + dim=0, + output_size=int(max_len * (max_len + 1) / 2), + ) + + denom = [] + for n in t_idx64.long().tolist(): + ar = torch.arange(n, device=device, dtype=torch.float64) + denom.append(((n - ar) * lam64.pow(ar)).sum()) + denom = torch.stack(denom) + + return (base / denom / len(terminating_idx)).to(dtype) + + +def test_geometric_within_contributions_matches_reference(): + terminating_idx = torch.tensor([1, 2, 3, 5], dtype=torch.int64) + trajectories = DummyTrajectories(terminating_idx) + + model = object.__new__(SubTBGFlowNet) + model.lamda = 0.9 + + result = model.get_geometric_within_contributions( + cast(Trajectories, trajectories) # type: ignore[arg-type] + ) + reference = _reference_contributions( + lam=model.lamda, + terminating_idx=terminating_idx, + max_len=trajectories.max_length, + dtype=result.dtype, + device=terminating_idx.device, + ) + + torch.testing.assert_close(result, reference, rtol=1e-6, atol=1e-6) + + +def test_geometric_within_contributions_near_one_is_stable(): + terminating_idx = torch.tensor([2, 4, 6], dtype=torch.int64) + trajectories = DummyTrajectories(terminating_idx) + + model_near_one = object.__new__(SubTBGFlowNet) + model_near_one.lamda = 1.0 - 1e-4 + + model_exact_one = object.__new__(SubTBGFlowNet) + model_exact_one.lamda = 1.0 + + result_near_one = model_near_one.get_geometric_within_contributions( + cast(Trajectories, trajectories) # type: ignore[arg-type] + ) + result_exact_one = model_exact_one.get_geometric_within_contributions( + cast(Trajectories, trajectories) # type: ignore[arg-type] + ) + + assert torch.isfinite(result_near_one).all() + assert torch.isfinite(result_exact_one).all() + torch.testing.assert_close(result_near_one, result_exact_one, rtol=5e-3, atol=5e-5) + + +@pytest.mark.parametrize("lam", [0.0, 0.3, 0.7, 0.9, 0.999, 1.0]) +def test_geometric_within_contributions_matches_bruteforce(lam: float): + torch.manual_seed(0) + terminating_idx = torch.randint(1, 7, size=(6,), dtype=torch.int64) + trajectories = DummyTrajectories(terminating_idx) + + model = object.__new__(SubTBGFlowNet) + model.lamda = lam + + result = model.get_geometric_within_contributions( + cast(Trajectories, trajectories) # type: ignore[arg-type] + ) + reference = _reference_contributions( + lam=model.lamda, + terminating_idx=terminating_idx, + max_len=trajectories.max_length, + dtype=result.dtype, + device=terminating_idx.device, + ) + + torch.testing.assert_close(result, reference, rtol=1e-6, atol=1e-6) diff --git a/tutorials/misc/bench_subtb_get_scores.py b/tutorials/misc/bench_subtb_get_scores.py index d723c87b..2a0370e7 100644 --- a/tutorials/misc/bench_subtb_get_scores.py +++ b/tutorials/misc/bench_subtb_get_scores.py @@ -31,18 +31,28 @@ def __len__(self) -> int: def build_model_and_data( - max_len: int, n_traj: int, seed: int = 0 + max_len: int, n_traj: int, seed: int = 0, device: str | torch.device | None = None ) -> Tuple[SubTBGFlowNet, _DummyTrajectories, list[torch.Tensor], list[torch.Tensor]]: torch.manual_seed(seed) - terminating_idx = torch.randint(1, max_len + 1, (n_traj,)) - log_pf_trajectories = torch.randn(max_len, n_traj) - log_pb_trajectories = torch.randn(max_len, n_traj) - log_state_flows = torch.randn(max_len, n_traj) - sink_states_mask = torch.zeros(max_len, n_traj, dtype=torch.bool) - is_terminal_mask = torch.zeros(max_len, n_traj, dtype=torch.bool) - - preds_list = [torch.randn(max_len + 1 - i, n_traj) for i in range(1, max_len + 1)] - targets_list = [torch.randn(max_len + 1 - i, n_traj) for i in range(1, max_len + 1)] + device = torch.device(device) if device is not None else torch.device("cpu") + terminating_idx = torch.randint(1, max_len + 1, (n_traj,), device=device) + # In the real pipeline, trajectories carry log_rewards computed from the env. + # The vectorized get_scores now asserts on its presence, so seed a dummy tensor here. + log_rewards = torch.randn(n_traj, device=device) + log_pf_trajectories = torch.randn(max_len, n_traj, device=device) + log_pb_trajectories = torch.randn(max_len, n_traj, device=device) + log_state_flows = torch.randn(max_len, n_traj, device=device) + sink_states_mask = torch.zeros(max_len, n_traj, dtype=torch.bool, device=device) + is_terminal_mask = torch.zeros(max_len, n_traj, dtype=torch.bool, device=device) + + preds_list = [ + torch.randn(max_len + 1 - i, n_traj, device=device) + for i in range(1, max_len + 1) + ] + targets_list = [ + torch.randn(max_len + 1 - i, n_traj, device=device) + for i in range(1, max_len + 1) + ] trajectories = _DummyTrajectories( terminating_idx=terminating_idx, max_length=max_len @@ -69,6 +79,8 @@ def build_model_and_data( lambda self, _log_state_flows, _traj: (sink_states_mask, is_terminal_mask), model, ) + # Attach log_rewards to the dummy trajectories to mirror real trajectories objects. + trajectories.log_rewards = log_rewards model.calculate_preds = MethodType( lambda self, _log_pf_cum, _log_state_flows, i: preds_list[i - 1], model ) @@ -140,11 +152,18 @@ def original_get_scores( return scores_orig, flattening_masks_orig -def run_once(mode: str, max_len: int, n_traj: int) -> float: - """Return median time (seconds) for the chosen mode.""" - model, trajectories, _, _ = build_model_and_data(max_len, n_traj) +def run_once( + mode: str, + max_len: int, + n_traj: int, + use_compile: bool = False, + device: str | torch.device = "cpu", +) -> float: + """Return median time (seconds) for the chosen mode. Optionally uses torch.compile and device selection.""" + model, trajectories, _, _ = build_model_and_data(max_len, n_traj, device=device) env_obj: Any = object() bench: Callable[[], Any] + compiled_get_scores: Callable | None = None if mode == "original": @@ -153,9 +172,19 @@ def bench_original(): bench = bench_original elif mode == "vectorized": + if use_compile: + # Compile only after monkeypatching, so we capture the correct bound method. + compiled_get_scores = torch.compile( + model.get_scores, fullgraph=False, dynamic=False, mode="reduce-overhead" + ) def bench_vectorized(): - return model.get_scores(env_obj, trajectories) # type: ignore[arg-type] + fn = ( + compiled_get_scores + if compiled_get_scores is not None + else model.get_scores + ) + return fn(env_obj, trajectories) # type: ignore[arg-type] bench = bench_vectorized else: @@ -178,6 +207,16 @@ def main(): nargs="+", default=["80x640", "160x1280", "320x2560"], ) + parser.add_argument( + "--compile", + action="store_true", + help="Use torch.compile on the vectorized get_scores.", + ) + parser.add_argument( + "--device", + default="cpu", + help="Device to run on (e.g., cpu, mps, cuda).", + ) args = parser.parse_args() print("Benchmarking SubTBGFlowNet.get_scores (CPU)") @@ -190,8 +229,17 @@ def main(): max_len_s, n_traj_s = size.lower().split("x") max_len = int(max_len_s) n_traj = int(n_traj_s) - t_orig = run_once("original", max_len, n_traj) * 1e3 - t_vec = run_once("vectorized", max_len, n_traj) * 1e3 + t_orig = run_once("original", max_len, n_traj, device=args.device) * 1e3 + t_vec = ( + run_once( + "vectorized", + max_len, + n_traj, + use_compile=args.compile, + device=args.device, + ) + * 1e3 + ) speedup = t_orig / t_vec if t_vec > 0 else float("inf") print(f"{size:>10} {t_orig:12.3f} {t_vec:12.3f} {speedup:8.2f}x") From 953de5a5a5fd5d6fe98fff7be7d67a147227975c Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 13 Dec 2025 03:22:02 -0500 Subject: [PATCH 10/16] get_scores optimized for DetailedBalance --- src/gfn/gflownet/detailed_balance.py | 75 ++++++++++++++++------------ 1 file changed, 44 insertions(+), 31 deletions(-) diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index f29fa919..a618b89a 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -194,6 +194,12 @@ def get_scores( states = transitions.states actions = transitions.actions + conditions = ( + transitions.conditions + ) # reuse locally to avoid repeated attribute lookups + next_states = ( + transitions.next_states + ) # reuse to avoid repeated attribute lookups if len(states) == 0: return torch.tensor(0.0, device=transitions.device) @@ -209,30 +215,39 @@ def get_scores( # Compute log_F_s # LogF is potentially a conditional computation. - if transitions.conditions is not None: + if conditions is not None: with has_conditions_exception_handler("logF", self.logF): - log_F_s = self.logF(states, transitions.conditions).squeeze(-1) + log_F_s = self.logF(states, conditions).squeeze(-1) else: with no_conditions_exception_handler("logF", self.logF): log_F_s = self.logF(states).squeeze(-1) # Compute log_F_s_next - log_F_s_next = torch.zeros_like(log_F_s) + # Preallocate once and fill; write terminating rewards first to avoid an extra zero fill. + log_F_s_next = torch.empty_like(log_F_s) is_terminating = transitions.is_terminating is_intermediate = ~is_terminating - # Assign log_F_s_next for intermediate next states - interm_next_states = transitions.next_states[is_intermediate] - # log_F is potentially a conditional computation. - if transitions.conditions is not None: - with has_conditions_exception_handler("logF", self.logF): - log_F_s_next[is_intermediate] = self.logF( - interm_next_states, - transitions.conditions[is_intermediate], - ).squeeze(-1) - else: - with no_conditions_exception_handler("logF", self.logF): - log_F_s_next[is_intermediate] = self.logF(interm_next_states).squeeze(-1) + # Assign log_F_s_next for terminating transitions directly from clamped rewards. + log_rewards = transitions.log_rewards + assert log_rewards is not None + log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) + log_F_s_next[is_terminating] = log_rewards[is_terminating] + + # Assign log_F_s_next for intermediate next states (skip work if none). + if torch.any(is_intermediate): + interm_idx = is_intermediate.nonzero(as_tuple=True)[0] + interm_next_states = next_states[interm_idx] + # log_F is potentially a conditional computation. + if conditions is not None: + with has_conditions_exception_handler("logF", self.logF): + log_F_s_next[interm_idx] = self.logF( + interm_next_states, + conditions[interm_idx], + ).squeeze(-1) + else: + with no_conditions_exception_handler("logF", self.logF): + log_F_s_next[interm_idx] = self.logF(interm_next_states).squeeze(-1) # Apply forward-looking if applicable if self.forward_looking: @@ -249,27 +264,25 @@ def get_scores( ) # Reward calculation can also be conditional. - if transitions.conditions is not None: - log_rewards_state = env.log_reward(states, transitions.conditions) # type: ignore - log_rewards_next = env.log_reward( - interm_next_states, transitions.conditions[is_intermediate] # type: ignore - ) + if conditions is not None: + log_rewards_state = env.log_reward(states, conditions) # type: ignore else: log_rewards_state = env.log_reward(states) - log_rewards_next = env.log_reward(interm_next_states) - log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min) - log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min) + log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min) log_F_s = log_F_s + log_rewards_state - log_F_s_next[is_intermediate] = ( - log_F_s_next[is_intermediate] + log_rewards_next - ) - # Assign log_F_s_next for terminating transitions as log_rewards - log_rewards = transitions.log_rewards - assert log_rewards is not None - log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) - log_F_s_next[is_terminating] = log_rewards[is_terminating] + if torch.any(is_intermediate): + if conditions is not None: + log_rewards_next = env.log_reward( + next_states[interm_idx], + conditions[interm_idx], # type: ignore + ) + else: + log_rewards_next = env.log_reward(next_states[interm_idx]) + + log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min) + log_F_s_next[interm_idx] = log_F_s_next[interm_idx] + log_rewards_next # Compute scores preds = log_pf + log_F_s From 3c554bab0b1ddf4bb7ae044a98bb00c7850d82cc Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 13 Dec 2025 03:33:50 -0500 Subject: [PATCH 11/16] added some benchmarking scripts (to be amalgamated) --- tutorials/misc/bench_db_get_scores.py | 388 +++++++++++++++++++++++ tutorials/misc/bench_moddb_get_scores.py | 382 ++++++++++++++++++++++ 2 files changed, 770 insertions(+) create mode 100644 tutorials/misc/bench_db_get_scores.py create mode 100644 tutorials/misc/bench_moddb_get_scores.py diff --git a/tutorials/misc/bench_db_get_scores.py b/tutorials/misc/bench_db_get_scores.py new file mode 100644 index 00000000..f2d62440 --- /dev/null +++ b/tutorials/misc/bench_db_get_scores.py @@ -0,0 +1,388 @@ +""" +Micro-benchmark for DBGFlowNet.get_scores (detailed balance) baseline vs +an optimized version. This mirrors the structure of +`tutorials/misc/bench_subtb_get_scores.py` but isolates the transition-based +DB path. +""" + +from __future__ import annotations + +import argparse +from types import MethodType +from typing import Any, Callable, Tuple + +import torch +from torch.utils import benchmark + +from gfn.gflownet.detailed_balance import DBGFlowNet + + +class _DummyStates: + """Minimal stand-in for States; keeps only what get_scores touches.""" + + def __init__(self, tensor: torch.Tensor, is_sink_state: torch.Tensor | None = None): + self.tensor = tensor + self.is_sink_state = ( + is_sink_state + if is_sink_state is not None + else torch.zeros(tensor.shape[0], dtype=torch.bool, device=tensor.device) + ) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_DummyStates": + # Preserve sink-state bookkeeping under boolean or slice indexing. + return _DummyStates(self.tensor[idx], self.is_sink_state[idx]) + + @property + def batch_shape(self) -> torch.Size: + # Matches the check used in check_compatibility when debug is enabled. + return self.tensor.shape[:-1] + + @property + def device(self) -> torch.device: + return self.tensor.device + + +class _DummyActions: + """Minimal stand-in for Actions; only batch_shape and tensor are needed here.""" + + # Keep an exit_action attribute for future compatibility (e.g., Modified DBG). + exit_action = torch.tensor(0, dtype=torch.long) + + def __init__(self, tensor: torch.Tensor): + self.tensor = tensor + self.is_exit = torch.zeros_like(tensor, dtype=torch.bool) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_DummyActions": + return _DummyActions(self.tensor[idx]) + + @property + def batch_shape(self) -> torch.Size: + return self.tensor.shape + + +class _DummyTransitions: + """Carries the attributes touched by DBGFlowNet.get_scores.""" + + def __init__( + self, + states: _DummyStates, + next_states: _DummyStates, + actions: _DummyActions, + is_terminating: torch.Tensor, + log_rewards: torch.Tensor, + conditions: torch.Tensor | None = None, + ): + self.states = states + self.next_states = next_states + self.actions = actions + self.is_terminating = is_terminating + self.is_backward = False + self.conditions = conditions + self.log_rewards = log_rewards + self.device = states.device + self.n_transitions = len(states) + + def __len__(self) -> int: + return self.n_transitions + + +class _DummyEnv: + """Lightweight env wrapper to supply log_reward.""" + + def __init__(self, log_reward_fn: Callable[[Any, Any | None], torch.Tensor]): + self._log_reward_fn = log_reward_fn + + def log_reward(self, states: Any, conditions: Any | None = None) -> torch.Tensor: + return self._log_reward_fn(states, conditions) + + +def build_model_and_data( + n_transitions: int, + seed: int = 0, + device: str | torch.device = "cpu", + forward_looking: bool = False, +) -> Tuple[DBGFlowNet, _DummyEnv, _DummyTransitions]: + """Set up a minimal DBGFlowNet + transitions for benchmarking.""" + torch.manual_seed(seed) + device = torch.device(device) + + # Synthetic data sized to stress memory without extra allocations in the hot path. + states_tensor = torch.randn(n_transitions, 4, device=device) + next_states_tensor = torch.randn(n_transitions, 4, device=device) + is_sink_state = torch.zeros(n_transitions, dtype=torch.bool, device=device) + states = _DummyStates(states_tensor, is_sink_state=is_sink_state) + next_states = _DummyStates(next_states_tensor, is_sink_state=is_sink_state.clone()) + + # Ensure a mix of terminating and intermediate transitions to exercise both branches. + is_terminating = torch.zeros(n_transitions, dtype=torch.bool, device=device) + is_terminating[::3] = True + actions = _DummyActions(torch.zeros(n_transitions, dtype=torch.long, device=device)) + log_rewards = torch.randn(n_transitions, device=device) + + transitions = _DummyTransitions( + states=states, + next_states=next_states, + actions=actions, + is_terminating=is_terminating, + log_rewards=log_rewards, + conditions=None, + ) + + # Precompute tensors so each benchmark iteration avoids fresh allocations. + log_pf = torch.randn(n_transitions, device=device) + log_pb = torch.randn(n_transitions, device=device) + logF_states = torch.randn(n_transitions, 1, device=device) + logF_next = torch.randn(n_transitions, 1, device=device) + log_reward_states = torch.randn(n_transitions, device=device) + log_reward_next = torch.randn(n_transitions, device=device) + + def get_pfs_and_pbs_stub(_self, _transitions, recalculate_all_logprobs: bool = True): + # Fixed tensors keep the timing focused on get_scores compute and masking. + return log_pf, log_pb + + def logF_stub(_self, s, _conditions=None): + # Return shape (..., 1) so the squeeze(-1) in get_scores matches real behavior. + length = len(s) + if length == n_transitions: + return logF_states + return logF_next[:length] + + def log_reward_stub(_states, _conditions=None): + # Forward-looking uses both current and next states; size guides which buffer to use. + length = len(_states) + if length == n_transitions: + return log_reward_states + return log_reward_next[:length] + + env = _DummyEnv(log_reward_stub) + + model = DBGFlowNet.__new__(DBGFlowNet) + torch.nn.Module.__init__(model) + # Minimal attribute set; we bypass __init__ to avoid heavyweight estimator setup. + model.debug = False + model.forward_looking = forward_looking + model.log_reward_clip_min = -float("inf") + model.get_pfs_and_pbs = MethodType(get_pfs_and_pbs_stub, model) + model.logF = MethodType(logF_stub, model) + + return model, env, transitions + + +def original_get_scores( + model: DBGFlowNet, + env: _DummyEnv, + transitions: _DummyTransitions, + recalculate_all_logprobs: bool = True, +) -> torch.Tensor: + """Copy of the current DBGFlowNet.get_scores for baseline timing.""" + # Guard bad inputs under debug to avoid graph breaks in torch.compile. + if model.debug and transitions.is_backward: + raise ValueError("Backward transitions are not supported") + + states = transitions.states + actions = transitions.actions + + if len(states) == 0: + return torch.tensor(0.0, device=transitions.device) + + if model.debug: + from gfn.gflownet.detailed_balance import check_compatibility + + check_compatibility(states, actions, transitions) # type: ignore[arg-type] + assert ( + not transitions.states.is_sink_state.any() + ), "Transition from sink state is not allowed. This is a bug." + + # Compute log_pf and log_pb + log_pf, log_pb = model.get_pfs_and_pbs( + transitions, recalculate_all_logprobs=recalculate_all_logprobs # type: ignore[arg-type] + ) + + # Compute log_F_s + # LogF is potentially a conditional computation. + if transitions.conditions is not None: + from gfn.utils.handlers import has_conditions_exception_handler + + with has_conditions_exception_handler("logF", model.logF): + log_F_s = model.logF(states, transitions.conditions).squeeze(-1) + else: + from gfn.utils.handlers import no_conditions_exception_handler + + with no_conditions_exception_handler("logF", model.logF): + log_F_s = model.logF(states).squeeze(-1) + + # Compute log_F_s_next + log_F_s_next = torch.zeros_like(log_F_s) + is_terminating = transitions.is_terminating + is_intermediate = ~is_terminating + + # Assign log_F_s_next for intermediate next states + interm_next_states = transitions.next_states[is_intermediate] + # log_F is potentially a conditional computation. + if transitions.conditions is not None: + from gfn.utils.handlers import has_conditions_exception_handler + + with has_conditions_exception_handler("logF", model.logF): + log_F_s_next[is_intermediate] = model.logF( + interm_next_states, + transitions.conditions[is_intermediate], + ).squeeze(-1) + else: + from gfn.utils.handlers import no_conditions_exception_handler + + with no_conditions_exception_handler("logF", model.logF): + log_F_s_next[is_intermediate] = model.logF(interm_next_states).squeeze(-1) + + # Apply forward-looking if applicable + if model.forward_looking: + # Reward calculation can also be conditional. + if transitions.conditions is not None: + log_rewards_state = env.log_reward(states, transitions.conditions) # type: ignore + log_rewards_next = env.log_reward( + interm_next_states, transitions.conditions[is_intermediate] # type: ignore + ) + else: + log_rewards_state = env.log_reward(states) + log_rewards_next = env.log_reward(interm_next_states) + + log_rewards_state = log_rewards_state.clamp_min(model.log_reward_clip_min) + log_rewards_next = log_rewards_next.clamp_min(model.log_reward_clip_min) + + log_F_s = log_F_s + log_rewards_state + log_F_s_next[is_intermediate] = log_F_s_next[is_intermediate] + log_rewards_next + + # Assign log_F_s_next for terminating transitions as log_rewards + log_rewards = transitions.log_rewards + assert log_rewards is not None + log_rewards = log_rewards.clamp_min(model.log_reward_clip_min) + log_F_s_next[is_terminating] = log_rewards[is_terminating] + + # Compute scores + preds = log_pf + log_F_s + targets = log_pb + log_F_s_next + scores = preds - targets + assert scores.shape == (transitions.n_transitions,) + return scores + + +def run_once( + mode: str, + n_transitions: int, + forward_looking: bool, + use_compile: bool = False, + device: str | torch.device = "cpu", +) -> float: + """Return median time (seconds) for the chosen mode.""" + model, env, transitions = build_model_and_data( + n_transitions=n_transitions, + forward_looking=forward_looking, + device=device, + ) + + bench: Callable[[], Any] + compiled_get_scores: Callable | None = None + + if mode == "original": + # Use the in-file copy of the current implementation to keep a fixed baseline. + def bench_original(): + return original_get_scores( + model, env, transitions, recalculate_all_logprobs=True + ) + + bench = bench_original + elif mode == "current": + # Benchmarks the method on the model; once optimized, this reflects new code. + if use_compile: + compiled_get_scores = torch.compile( + model.get_scores, + fullgraph=False, + dynamic=False, + mode="reduce-overhead", + ) + + def bench_current(): + fn = ( + compiled_get_scores + if compiled_get_scores is not None + else model.get_scores + ) + return fn(env, transitions) # type: ignore[arg-type] + + bench = bench_current + else: + raise ValueError(mode) + + t = benchmark.Timer( + stmt="bench()", + globals={"bench": bench}, + setup="", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=0.5) + return t.median + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--sizes", + nargs="+", + default=["65536", "131072", "262144"], + help="Number of transitions per batch to benchmark (larger to surface runtime differences).", + ) + parser.add_argument( + "--compile", + action="store_true", + help="Use torch.compile for the current/optimized get_scores.", + ) + parser.add_argument( + "--forward-looking", + action="store_true", + help="Enable forward-looking reward path in the benchmark.", + ) + parser.add_argument( + "--device", + default="cpu", + help="Device to run on (e.g., cpu, mps, cuda).", + ) + args = parser.parse_args() + + print("Benchmarking DBGFlowNet.get_scores (Detailed Balance)") + print(f"torch version: {torch.__version__}") + print(f"num threads: {torch.get_num_threads()}") + print(f"forward-looking: {args.forward_looking}") + print() + print(f"{'n_trans':>10} {'orig (ms)':>12} {'curr (ms)':>12} {'speedup':>8}") + + for size in args.sizes: + n_transitions = int(size) + t_orig = ( + run_once( + "original", + n_transitions, + forward_looking=args.forward_looking, + device=args.device, + ) + * 1e3 + ) + t_curr = ( + run_once( + "current", + n_transitions, + forward_looking=args.forward_looking, + use_compile=args.compile, + device=args.device, + ) + * 1e3 + ) + speedup = t_orig / t_curr if t_curr > 0 else float("inf") + print(f"{n_transitions:10d} {t_orig:12.3f} {t_curr:12.3f} {speedup:8.2f}x") + + +if __name__ == "__main__": + main() diff --git a/tutorials/misc/bench_moddb_get_scores.py b/tutorials/misc/bench_moddb_get_scores.py new file mode 100644 index 00000000..0ea8f0c2 --- /dev/null +++ b/tutorials/misc/bench_moddb_get_scores.py @@ -0,0 +1,382 @@ +""" +Micro-benchmark for ModifiedDBGFlowNet.get_scores baseline vs optimized. +Modeled after bench_db_get_scores.py but targets the modified DB path. +""" + +from __future__ import annotations + +import argparse +from typing import Any, Callable, Tuple + +import torch +from torch.utils import benchmark + +from gfn.gflownet.detailed_balance import ModifiedDBGFlowNet +from gfn.utils.handlers import ( + has_conditions_exception_handler, + no_conditions_exception_handler, +) + + +class _DummyStates: + """Minimal stand-in for States; keeps only what get_scores touches.""" + + def __init__(self, tensor: torch.Tensor, is_sink_state: torch.Tensor | None = None): + self.tensor = tensor + self.is_sink_state = ( + is_sink_state + if is_sink_state is not None + else torch.zeros(tensor.shape[0], dtype=torch.bool, device=tensor.device) + ) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_DummyStates": + return _DummyStates(self.tensor[idx], self.is_sink_state[idx]) + + @property + def device(self) -> torch.device: + return self.tensor.device + + +class _DummyActions: + """Minimal stand-in for Actions; only tensor and is_exit are needed here.""" + + exit_action = torch.tensor(0, dtype=torch.long) + + def __init__(self, tensor: torch.Tensor, is_exit: torch.Tensor | None = None): + self.tensor = tensor + self.is_exit = ( + is_exit + if is_exit is not None + else torch.zeros_like(tensor, dtype=torch.bool) + ) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_DummyActions": + return _DummyActions(self.tensor[idx], self.is_exit[idx]) + + +class _DummyTransitions: + """Carries the attributes touched by ModifiedDBGFlowNet.get_scores.""" + + def __init__( + self, + states: _DummyStates, + next_states: _DummyStates, + actions: _DummyActions, + all_log_rewards: torch.Tensor, + is_backward: bool = False, + log_probs: torch.Tensor | None = None, + has_log_probs: bool = False, + conditions: torch.Tensor | None = None, + ): + self.states = states + self.next_states = next_states + self.actions = actions + self.all_log_rewards = all_log_rewards + self.is_backward = is_backward + self.log_probs = log_probs + self.has_log_probs = has_log_probs + self.conditions = conditions + self.device = states.device + self.n_transitions = len(states) + + def __len__(self) -> int: + return self.n_transitions + + def __getitem__(self, idx) -> "_DummyTransitions": + return _DummyTransitions( + self.states[idx], + self.next_states[idx], + self.actions[idx], + self.all_log_rewards[idx], + self.is_backward, + self.log_probs[idx] if self.log_probs is not None else None, + self.has_log_probs, + self.conditions[idx] if self.conditions is not None else None, + ) + + +class _FakeDist: + """Simple distribution wrapper returning preset log-probs.""" + + def __init__(self, log_action: torch.Tensor, log_exit: torch.Tensor): + self._log_action = log_action + self._log_exit = log_exit + + def log_prob(self, action_tensor: torch.Tensor) -> torch.Tensor: + # Match shape to input; ignore actual action values to focus on timing. + n = action_tensor.shape[0] + # Broadcasting to match shape; slicing guards shorter inputs (next_states path). + if action_tensor.shape == self._log_exit.shape: + return self._log_exit + return self._log_action[:n] + + +class _DummyEstimator: + """Estimator stub providing to_probability_distribution and call signature.""" + + def __init__( + self, + log_action: torch.Tensor, + log_exit: torch.Tensor, + ): + self._log_action = log_action + self._log_exit = log_exit + + def __call__(self, states: _DummyStates, conditions=None): + # Return a placeholder; not used by FakeDist. + return None + + def to_probability_distribution(self, states: _DummyStates, module_output=None): + # Provide a fresh FakeDist per call to mirror API shape; uses preset tensors. + return _FakeDist(self._log_action, self._log_exit) + + +def build_model_and_data( + n_transitions: int, + seed: int = 0, + device: str | torch.device = "cpu", +) -> Tuple[ModifiedDBGFlowNet, _DummyTransitions]: + """Set up a minimal ModifiedDBGFlowNet + transitions for benchmarking.""" + torch.manual_seed(seed) + device = torch.device(device) + + states_tensor = torch.randn(n_transitions, 4, device=device) + next_states_tensor = torch.randn(n_transitions, 4, device=device) + + # Mix of sink/non-sink next states to exercise masking. + is_sink_state = torch.zeros(n_transitions, dtype=torch.bool, device=device) + is_sink_state[::4] = True + states = _DummyStates(states_tensor, is_sink_state=torch.zeros_like(is_sink_state)) + next_states = _DummyStates(next_states_tensor, is_sink_state=is_sink_state) + + # Actions and exits. + actions_tensor = torch.randint(0, 5, (n_transitions,), device=device) + is_exit = torch.zeros_like(actions_tensor, dtype=torch.bool) + actions = _DummyActions(actions_tensor, is_exit=is_exit) + + # Rewards for (state, next_state) pairs as expected by ModifiedDB. + all_log_rewards = torch.randn(n_transitions, 2, device=device) + + transitions = _DummyTransitions( + states=states, + next_states=next_states, + actions=actions, + all_log_rewards=all_log_rewards, + has_log_probs=False, + log_probs=None, + conditions=None, + ) + + # Precomputed log-probs for pf/pb distributions. + # Keep same length as non-sink count to align with mask slices. + non_sink_count = int((~is_sink_state).sum().item()) + log_pf_action = torch.randn(non_sink_count, device=device) + log_pf_exit = torch.randn(non_sink_count, device=device) + log_pf_exit_next = torch.randn(non_sink_count, device=device) + log_pb_action = torch.randn(non_sink_count, device=device) + + pf_estimator = _DummyEstimator(log_pf_action, log_pf_exit) + pb_estimator = _DummyEstimator(log_pb_action, log_pf_exit_next) + + model = ModifiedDBGFlowNet.__new__(ModifiedDBGFlowNet) + torch.nn.Module.__init__(model) + # Minimal attribute set; bypass __init__ to avoid heavy setup. + model.debug = False + model.constant_pb = False + model.pf = pf_estimator + model.pb = pb_estimator + model.log_reward_clip_min = -float("inf") + + return model, transitions + + +def original_get_scores( + model: ModifiedDBGFlowNet, + transitions: _DummyTransitions, + recalculate_all_logprobs: bool = True, +) -> torch.Tensor: + """Copy of ModifiedDBGFlowNet.get_scores for baseline timing.""" + if model.debug and transitions.is_backward: + raise ValueError("Backward transitions are not supported") + + if len(transitions) == 0: + return torch.tensor(0.0, device=transitions.device) + + mask = ~transitions.next_states.is_sink_state + states = transitions.states[mask] + valid_next_states = transitions.next_states[mask] + actions = transitions.actions[mask] + all_log_rewards = transitions.all_log_rewards[mask] + + if model.debug: + from gfn.gflownet.detailed_balance import check_compatibility + + check_compatibility(states, actions, transitions) # type: ignore[arg-type] + + if transitions.conditions is not None: + with has_conditions_exception_handler("pf", model.pf): # type: ignore[name-defined] + module_output = model.pf(states, transitions.conditions[mask]) + else: + with no_conditions_exception_handler("pf", model.pf): # type: ignore[name-defined] + module_output = model.pf(states) + + if len(states) == 0: + return torch.tensor(0.0, device=transitions.device) + + pf_dist = model.pf.to_probability_distribution(states, module_output) # type: ignore[arg-type] + + if transitions.has_log_probs and not recalculate_all_logprobs: + valid_log_pf_actions = transitions[mask].log_probs + assert valid_log_pf_actions is not None + else: + valid_log_pf_actions = pf_dist.log_prob(actions.tensor) + exit_action_tensor = actions.__class__.exit_action.to( + actions.tensor.device, dtype=actions.tensor.dtype + ).expand_as(actions.tensor) + valid_log_pf_s_exit = pf_dist.log_prob(exit_action_tensor) + + if transitions.conditions is not None: + with has_conditions_exception_handler("pf", model.pf): # type: ignore[name-defined] + module_output = model.pf(valid_next_states, transitions.conditions[mask]) + else: + with no_conditions_exception_handler("pf", model.pf): # type: ignore[name-defined] + module_output = model.pf(valid_next_states) + + valid_log_pf_s_prime_exit = model.pf.to_probability_distribution( + valid_next_states, module_output # type: ignore[arg-type] + ).log_prob(exit_action_tensor[: len(valid_next_states)]) + + non_exit_actions = actions[~actions.is_exit] + + if model.pb is not None: + if transitions.conditions is not None: + with has_conditions_exception_handler("pb", model.pb): # type: ignore[name-defined] + module_output = model.pb(valid_next_states, transitions.conditions[mask]) + else: + with no_conditions_exception_handler("pb", model.pb): # type: ignore[name-defined] + module_output = model.pb(valid_next_states) + + valid_log_pb_actions = model.pb.to_probability_distribution( + valid_next_states, module_output # type: ignore[arg-type] + ).log_prob(non_exit_actions.tensor) + else: + valid_log_pb_actions = torch.zeros_like(valid_log_pf_s_exit) + + preds = all_log_rewards[:, 0] + valid_log_pf_actions + valid_log_pf_s_prime_exit + targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit + + scores = preds - targets + if model.debug and torch.any(torch.isinf(scores)): + raise ValueError("scores contains inf") + + return scores + + +def run_once( + mode: str, + n_transitions: int, + use_compile: bool = False, + device: str | torch.device = "cpu", +) -> float: + """Return median time (seconds) for the chosen mode.""" + model, transitions = build_model_and_data( + n_transitions=n_transitions, + device=device, + ) + + bench: Callable[[], Any] + compiled_get_scores: Callable | None = None + + if mode == "original": + + def bench_original(): + return original_get_scores(model, transitions, recalculate_all_logprobs=True) + + bench = bench_original + elif mode == "current": + if use_compile: + compiled_get_scores = torch.compile( + model.get_scores, + fullgraph=False, + dynamic=False, + mode="reduce-overhead", + ) + + def bench_current(): + fn = ( + compiled_get_scores + if compiled_get_scores is not None + else model.get_scores + ) + return fn(transitions) # type: ignore[arg-type] + + bench = bench_current + else: + raise ValueError(mode) + + t = benchmark.Timer( + stmt="bench()", + globals={"bench": bench}, + setup="", + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=0.5) + return t.median + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--sizes", + nargs="+", + default=["65536", "131072", "262144"], + help="Number of transitions per batch to benchmark (larger to surface runtime differences).", + ) + parser.add_argument( + "--compile", + action="store_true", + help="Use torch.compile for the current/optimized get_scores.", + ) + parser.add_argument( + "--device", + default="cpu", + help="Device to run on (e.g., cpu, mps, cuda).", + ) + args = parser.parse_args() + + print("Benchmarking ModifiedDBGFlowNet.get_scores (Modified DB)") + print(f"torch version: {torch.__version__}") + print(f"num threads: {torch.get_num_threads()}") + print() + print(f"{'n_trans':>10} {'orig (ms)':>12} {'curr (ms)':>12} {'speedup':>8}") + + for size in args.sizes: + n_transitions = int(size) + t_orig = ( + run_once( + "original", + n_transitions, + device=args.device, + ) + * 1e3 + ) + t_curr = ( + run_once( + "current", + n_transitions, + use_compile=args.compile, + device=args.device, + ) + * 1e3 + ) + speedup = t_orig / t_curr if t_curr > 0 else float("inf") + print(f"{n_transitions:10d} {t_orig:12.3f} {t_curr:12.3f} {speedup:8.2f}x") + + +if __name__ == "__main__": + main() From a765178619ff05f73a0a419239c5b3872461460e Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 13 Dec 2025 09:38:57 -0500 Subject: [PATCH 12/16] black --- src/gfn/gflownet/detailed_balance.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index a618b89a..84c5f40c 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -405,18 +405,25 @@ def get_scores( if len(transitions) == 0: return torch.tensor(0.0, device=transitions.device) + conditions = ( + transitions.conditions + ) # reuse locally to avoid repeated attribute lookups mask = ~transitions.next_states.is_sink_state states = transitions.states[mask] valid_next_states = transitions.next_states[mask] actions = transitions.actions[mask] all_log_rewards = transitions.all_log_rewards[mask] + # If no non-sink next states, bail out early to avoid needless estimator calls. + if len(states) == 0: + return torch.tensor(0.0, device=transitions.device) + if self.debug: check_compatibility(states, actions, transitions) - if transitions.conditions is not None: + if conditions is not None: with has_conditions_exception_handler("pf", self.pf): - module_output = self.pf(states, transitions.conditions[mask]) + module_output = self.pf(states, conditions[mask]) else: with no_conditions_exception_handler("pf", self.pf): module_output = self.pf(states) @@ -440,9 +447,9 @@ def get_scores( # The following two lines are slightly inefficient, given that most # next_states are also states, for which we already did a forward pass. - if transitions.conditions is not None: + if conditions is not None: with has_conditions_exception_handler("pf", self.pf): - module_output = self.pf(valid_next_states, transitions.conditions[mask]) + module_output = self.pf(valid_next_states, conditions[mask]) else: with no_conditions_exception_handler("pf", self.pf): module_output = self.pf(valid_next_states) @@ -454,11 +461,9 @@ def get_scores( non_exit_actions = actions[~actions.is_exit] if self.pb is not None: - if transitions.conditions is not None: + if conditions is not None: with has_conditions_exception_handler("pb", self.pb): - module_output = self.pb( - valid_next_states, transitions.conditions[mask] - ) + module_output = self.pb(valid_next_states, conditions[mask]) else: with no_conditions_exception_handler("pb", self.pb): module_output = self.pb(valid_next_states) From 73bbaa58c9a943bd723a37df1863ddd1c6588490 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 13 Dec 2025 10:13:28 -0500 Subject: [PATCH 13/16] very minor modified DB optimizations --- src/gfn/gflownet/detailed_balance.py | 35 +++++++++++++++++----------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 84c5f40c..801cfa7a 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -421,17 +421,19 @@ def get_scores( if self.debug: check_compatibility(states, actions, transitions) + # Single pf forward for current states; reuse the resulting distribution for both + # taken-action log-probs and exit-action log-probs to avoid extra forwards. if conditions is not None: with has_conditions_exception_handler("pf", self.pf): - module_output = self.pf(states, conditions[mask]) + pf_outputs = self.pf(states, conditions[mask]) else: with no_conditions_exception_handler("pf", self.pf): - module_output = self.pf(states) + pf_outputs = self.pf(states) if len(states) == 0: return torch.tensor(0.0, device=transitions.device) - pf_dist = self.pf.to_probability_distribution(states, module_output) + pf_dist = self.pf.to_probability_distribution(states, pf_outputs) if transitions.has_log_probs and not recalculate_all_logprobs: valid_log_pf_actions = transitions[mask].log_probs @@ -445,18 +447,23 @@ def get_scores( ).expand_as(actions.tensor) valid_log_pf_s_exit = pf_dist.log_prob(exit_action_tensor) - # The following two lines are slightly inefficient, given that most - # next_states are also states, for which we already did a forward pass. - if conditions is not None: - with has_conditions_exception_handler("pf", self.pf): - module_output = self.pf(valid_next_states, conditions[mask]) - else: - with no_conditions_exception_handler("pf", self.pf): - module_output = self.pf(valid_next_states) + # Reuse the exit_action tensor and create the next-state distribution once; this + # avoids an additional forward or repeated log_prob calls. + if len(valid_next_states) > 0: + if conditions is not None: + with has_conditions_exception_handler("pf", self.pf): + pf_next_outputs = self.pf(valid_next_states, conditions[mask]) + else: + with no_conditions_exception_handler("pf", self.pf): + pf_next_outputs = self.pf(valid_next_states) - valid_log_pf_s_prime_exit = self.pf.to_probability_distribution( - valid_next_states, module_output - ).log_prob(exit_action_tensor[: len(valid_next_states)]) + pf_next_dist = self.pf.to_probability_distribution( + valid_next_states, pf_next_outputs + ) + valid_log_pf_s_prime_exit = pf_next_dist.log_prob(exit_action_tensor) + else: + # Should be rare because of the early return above; keep shape-friendly zero. + valid_log_pf_s_prime_exit = torch.zeros_like(valid_log_pf_s_exit) non_exit_actions = actions[~actions.is_exit] From c328e63179d45d8cfcb1b74eed18c432a77f7fd9 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 13 Dec 2025 11:24:00 -0500 Subject: [PATCH 14/16] micro-optimizations for tb-style losses --- src/gfn/gflownet/base.py | 20 ++++++++++++++------ src/gfn/gflownet/trajectory_balance.py | 11 ++++++----- src/gfn/utils/prob_calculations.py | 25 ++++++++++++++++--------- 3 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 57e4897a..9a463436 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -1,4 +1,3 @@ -import math import warnings from abc import ABC, abstractmethod from typing import Any, Generic, Tuple, TypeVar @@ -381,13 +380,14 @@ def get_scores( ) assert log_pf_trajectories is not None - total_log_pf_trajectories = log_pf_trajectories.sum(dim=0) - total_log_pb_trajectories = log_pb_trajectories.sum(dim=0) + total_log_pf_trajectories = log_pf_trajectories.sum(dim=0) # [N] + total_log_pb_trajectories = log_pb_trajectories.sum(dim=0) # [N] log_rewards = trajectories.log_rewards assert log_rewards is not None - - if math.isfinite(self.log_reward_clip_min): + # Fast path: skip clamp when log_reward_clip_min is -inf to avoid extra work. + # TODO: Do we need log reward clamping at all? + if self.log_reward_clip_min != float("-inf"): log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) # Keep runtime safety checks under `debug` to avoid graph breaks in torch.compile. @@ -399,7 +399,15 @@ def get_scores( assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,) assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,) - return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards + # Fused (pf - pb) then subtract rewards; keep it branch-free/out-of-place + # to stay friendly to torch.compile graphs. + scores = torch.sub( + total_log_pf_trajectories, total_log_pb_trajectories, alpha=1.0 + ) + # Subtract rewards in a separate op to avoid in-place mutations (graph-stable) + # while still keeping only one extra temporary. + scores = scores - log_rewards + return scores def to_training_samples(self, trajectories: Trajectories) -> Trajectories: """Returns the input trajectories as training samples. diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 4cea01fa..3283875b 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -126,12 +126,12 @@ def loss( if trajectories.conditions is not None: with is_callable_exception_handler("logZ", self.logZ): assert isinstance(self.logZ, ScalarEstimator) - logZ = self.logZ(trajectories.conditions) + logZ = self.logZ(trajectories.conditions) # [N] or [..., 1] else: - logZ = self.logZ + logZ = self.logZ # [] - logZ = cast(torch.Tensor, logZ) - scores = (scores + logZ.squeeze()).pow(2) + logZ = cast(torch.Tensor, logZ).squeeze() # [] or [N] + scores = torch.square(scores + logZ) # [N] loss = loss_reduce(scores, reduction) if self.debug and torch.isnan(loss).any(): raise ValueError("loss is nan") @@ -182,7 +182,8 @@ def loss( scores = self.get_scores( trajectories, recalculate_all_logprobs=recalculate_all_logprobs ) - scores = (scores - scores.mean()).pow(2) + scores = scores.sub_(scores.mean()) # [N], in-place mean-centering. + scores = torch.square(scores) # [N] loss = loss_reduce(scores, reduction) if self.debug and torch.isnan(loss).any(): raise ValueError("loss is NaN.") diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 1047dc98..f0db5b38 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -178,19 +178,19 @@ def get_trajectory_pfs( else: # Vectorized path. - log_pf_trajectories = torch.full_like( - trajectories.actions.tensor[..., 0], + # Allocate log_pf explicitly as floating point to avoid silent int casts. + log_pf_trajectories = torch.full( + trajectories.actions.tensor[..., 0].shape, fill_value=fill_value, dtype=torch.get_default_dtype(), + device=trajectories.states.device, ) if len(valid_states) == 0: return log_pf_trajectories # Build conditions per-step shape to align with valid_states - masked_cond = None cond = trajectories.conditions - if cond is not None: T = trajectories.states.tensor.shape[0] # If conditions already has time dim (T, N, ...), index directly. @@ -199,6 +199,8 @@ def get_trajectory_pfs( else: # Broadcast (N, ...) to (T, N, ...), then index. masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask] + else: + masked_cond = None # avoids building an expanded view when unused ctx_v = policy_pf.init_context( int(len(valid_states)), @@ -267,10 +269,12 @@ def get_trajectory_pbs( if trajectories.is_backward: raise ValueError("Backward trajectories are not supported") - log_pb_trajectories = torch.full_like( - trajectories.actions.tensor[..., 0], + # Allocate log_pb explicitly as floating point to avoid silent int casts. + log_pb_trajectories = torch.full( + trajectories.actions.tensor[..., 0].shape, fill_value=fill_value, - dtype=torch.get_default_dtype(), # Floating point dtype. + dtype=torch.get_default_dtype(), + device=trajectories.states.device, ) # Note the different mask for valid states and actions compared to the pf case. @@ -292,7 +296,6 @@ def get_trajectory_pbs( # Using all non-initial states, calculate the backward policy, and the logprobs # of those actions. - masked_cond = None cond = trajectories.conditions if cond is not None: T = trajectories.states.tensor.shape[0] @@ -300,12 +303,16 @@ def get_trajectory_pbs( masked_cond = cond[state_mask] else: masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask] + else: + masked_cond = None # skip broadcasting when no conditions supplied # There is no backward policy in this case. if pb is None: # If pb is None, we assume that the gflownet DAG is a tree, and therefore # the backward policy probability is always 1 (log probs are 0). - valid_log_pb_actions = torch.zeros_like(valid_actions.tensor) + valid_log_pb_actions = torch.zeros_like( + valid_actions.tensor, dtype=torch.get_default_dtype() + ) valid_log_pb_actions = valid_log_pb_actions.squeeze(-1) # no padding. log_pb_trajectories[action_mask] = valid_log_pb_actions.to( log_pb_trajectories.dtype, copy=False From 9eb52c8c99e5b120359b6450f4b40d7bb9299f28 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 13 Dec 2025 11:55:46 -0500 Subject: [PATCH 15/16] added benchmark results for optimization --- tutorials/misc/bench_db_get_scores.py | 388 -------- tutorials/misc/bench_get_scores_all.py | 1151 ++++++++++++++++++++++ tutorials/misc/bench_moddb_get_scores.py | 382 ------- tutorials/misc/bench_subtb_get_scores.py | 248 ----- 4 files changed, 1151 insertions(+), 1018 deletions(-) delete mode 100644 tutorials/misc/bench_db_get_scores.py create mode 100644 tutorials/misc/bench_get_scores_all.py delete mode 100644 tutorials/misc/bench_moddb_get_scores.py delete mode 100644 tutorials/misc/bench_subtb_get_scores.py diff --git a/tutorials/misc/bench_db_get_scores.py b/tutorials/misc/bench_db_get_scores.py deleted file mode 100644 index f2d62440..00000000 --- a/tutorials/misc/bench_db_get_scores.py +++ /dev/null @@ -1,388 +0,0 @@ -""" -Micro-benchmark for DBGFlowNet.get_scores (detailed balance) baseline vs -an optimized version. This mirrors the structure of -`tutorials/misc/bench_subtb_get_scores.py` but isolates the transition-based -DB path. -""" - -from __future__ import annotations - -import argparse -from types import MethodType -from typing import Any, Callable, Tuple - -import torch -from torch.utils import benchmark - -from gfn.gflownet.detailed_balance import DBGFlowNet - - -class _DummyStates: - """Minimal stand-in for States; keeps only what get_scores touches.""" - - def __init__(self, tensor: torch.Tensor, is_sink_state: torch.Tensor | None = None): - self.tensor = tensor - self.is_sink_state = ( - is_sink_state - if is_sink_state is not None - else torch.zeros(tensor.shape[0], dtype=torch.bool, device=tensor.device) - ) - - def __len__(self) -> int: - return self.tensor.shape[0] - - def __getitem__(self, idx) -> "_DummyStates": - # Preserve sink-state bookkeeping under boolean or slice indexing. - return _DummyStates(self.tensor[idx], self.is_sink_state[idx]) - - @property - def batch_shape(self) -> torch.Size: - # Matches the check used in check_compatibility when debug is enabled. - return self.tensor.shape[:-1] - - @property - def device(self) -> torch.device: - return self.tensor.device - - -class _DummyActions: - """Minimal stand-in for Actions; only batch_shape and tensor are needed here.""" - - # Keep an exit_action attribute for future compatibility (e.g., Modified DBG). - exit_action = torch.tensor(0, dtype=torch.long) - - def __init__(self, tensor: torch.Tensor): - self.tensor = tensor - self.is_exit = torch.zeros_like(tensor, dtype=torch.bool) - - def __len__(self) -> int: - return self.tensor.shape[0] - - def __getitem__(self, idx) -> "_DummyActions": - return _DummyActions(self.tensor[idx]) - - @property - def batch_shape(self) -> torch.Size: - return self.tensor.shape - - -class _DummyTransitions: - """Carries the attributes touched by DBGFlowNet.get_scores.""" - - def __init__( - self, - states: _DummyStates, - next_states: _DummyStates, - actions: _DummyActions, - is_terminating: torch.Tensor, - log_rewards: torch.Tensor, - conditions: torch.Tensor | None = None, - ): - self.states = states - self.next_states = next_states - self.actions = actions - self.is_terminating = is_terminating - self.is_backward = False - self.conditions = conditions - self.log_rewards = log_rewards - self.device = states.device - self.n_transitions = len(states) - - def __len__(self) -> int: - return self.n_transitions - - -class _DummyEnv: - """Lightweight env wrapper to supply log_reward.""" - - def __init__(self, log_reward_fn: Callable[[Any, Any | None], torch.Tensor]): - self._log_reward_fn = log_reward_fn - - def log_reward(self, states: Any, conditions: Any | None = None) -> torch.Tensor: - return self._log_reward_fn(states, conditions) - - -def build_model_and_data( - n_transitions: int, - seed: int = 0, - device: str | torch.device = "cpu", - forward_looking: bool = False, -) -> Tuple[DBGFlowNet, _DummyEnv, _DummyTransitions]: - """Set up a minimal DBGFlowNet + transitions for benchmarking.""" - torch.manual_seed(seed) - device = torch.device(device) - - # Synthetic data sized to stress memory without extra allocations in the hot path. - states_tensor = torch.randn(n_transitions, 4, device=device) - next_states_tensor = torch.randn(n_transitions, 4, device=device) - is_sink_state = torch.zeros(n_transitions, dtype=torch.bool, device=device) - states = _DummyStates(states_tensor, is_sink_state=is_sink_state) - next_states = _DummyStates(next_states_tensor, is_sink_state=is_sink_state.clone()) - - # Ensure a mix of terminating and intermediate transitions to exercise both branches. - is_terminating = torch.zeros(n_transitions, dtype=torch.bool, device=device) - is_terminating[::3] = True - actions = _DummyActions(torch.zeros(n_transitions, dtype=torch.long, device=device)) - log_rewards = torch.randn(n_transitions, device=device) - - transitions = _DummyTransitions( - states=states, - next_states=next_states, - actions=actions, - is_terminating=is_terminating, - log_rewards=log_rewards, - conditions=None, - ) - - # Precompute tensors so each benchmark iteration avoids fresh allocations. - log_pf = torch.randn(n_transitions, device=device) - log_pb = torch.randn(n_transitions, device=device) - logF_states = torch.randn(n_transitions, 1, device=device) - logF_next = torch.randn(n_transitions, 1, device=device) - log_reward_states = torch.randn(n_transitions, device=device) - log_reward_next = torch.randn(n_transitions, device=device) - - def get_pfs_and_pbs_stub(_self, _transitions, recalculate_all_logprobs: bool = True): - # Fixed tensors keep the timing focused on get_scores compute and masking. - return log_pf, log_pb - - def logF_stub(_self, s, _conditions=None): - # Return shape (..., 1) so the squeeze(-1) in get_scores matches real behavior. - length = len(s) - if length == n_transitions: - return logF_states - return logF_next[:length] - - def log_reward_stub(_states, _conditions=None): - # Forward-looking uses both current and next states; size guides which buffer to use. - length = len(_states) - if length == n_transitions: - return log_reward_states - return log_reward_next[:length] - - env = _DummyEnv(log_reward_stub) - - model = DBGFlowNet.__new__(DBGFlowNet) - torch.nn.Module.__init__(model) - # Minimal attribute set; we bypass __init__ to avoid heavyweight estimator setup. - model.debug = False - model.forward_looking = forward_looking - model.log_reward_clip_min = -float("inf") - model.get_pfs_and_pbs = MethodType(get_pfs_and_pbs_stub, model) - model.logF = MethodType(logF_stub, model) - - return model, env, transitions - - -def original_get_scores( - model: DBGFlowNet, - env: _DummyEnv, - transitions: _DummyTransitions, - recalculate_all_logprobs: bool = True, -) -> torch.Tensor: - """Copy of the current DBGFlowNet.get_scores for baseline timing.""" - # Guard bad inputs under debug to avoid graph breaks in torch.compile. - if model.debug and transitions.is_backward: - raise ValueError("Backward transitions are not supported") - - states = transitions.states - actions = transitions.actions - - if len(states) == 0: - return torch.tensor(0.0, device=transitions.device) - - if model.debug: - from gfn.gflownet.detailed_balance import check_compatibility - - check_compatibility(states, actions, transitions) # type: ignore[arg-type] - assert ( - not transitions.states.is_sink_state.any() - ), "Transition from sink state is not allowed. This is a bug." - - # Compute log_pf and log_pb - log_pf, log_pb = model.get_pfs_and_pbs( - transitions, recalculate_all_logprobs=recalculate_all_logprobs # type: ignore[arg-type] - ) - - # Compute log_F_s - # LogF is potentially a conditional computation. - if transitions.conditions is not None: - from gfn.utils.handlers import has_conditions_exception_handler - - with has_conditions_exception_handler("logF", model.logF): - log_F_s = model.logF(states, transitions.conditions).squeeze(-1) - else: - from gfn.utils.handlers import no_conditions_exception_handler - - with no_conditions_exception_handler("logF", model.logF): - log_F_s = model.logF(states).squeeze(-1) - - # Compute log_F_s_next - log_F_s_next = torch.zeros_like(log_F_s) - is_terminating = transitions.is_terminating - is_intermediate = ~is_terminating - - # Assign log_F_s_next for intermediate next states - interm_next_states = transitions.next_states[is_intermediate] - # log_F is potentially a conditional computation. - if transitions.conditions is not None: - from gfn.utils.handlers import has_conditions_exception_handler - - with has_conditions_exception_handler("logF", model.logF): - log_F_s_next[is_intermediate] = model.logF( - interm_next_states, - transitions.conditions[is_intermediate], - ).squeeze(-1) - else: - from gfn.utils.handlers import no_conditions_exception_handler - - with no_conditions_exception_handler("logF", model.logF): - log_F_s_next[is_intermediate] = model.logF(interm_next_states).squeeze(-1) - - # Apply forward-looking if applicable - if model.forward_looking: - # Reward calculation can also be conditional. - if transitions.conditions is not None: - log_rewards_state = env.log_reward(states, transitions.conditions) # type: ignore - log_rewards_next = env.log_reward( - interm_next_states, transitions.conditions[is_intermediate] # type: ignore - ) - else: - log_rewards_state = env.log_reward(states) - log_rewards_next = env.log_reward(interm_next_states) - - log_rewards_state = log_rewards_state.clamp_min(model.log_reward_clip_min) - log_rewards_next = log_rewards_next.clamp_min(model.log_reward_clip_min) - - log_F_s = log_F_s + log_rewards_state - log_F_s_next[is_intermediate] = log_F_s_next[is_intermediate] + log_rewards_next - - # Assign log_F_s_next for terminating transitions as log_rewards - log_rewards = transitions.log_rewards - assert log_rewards is not None - log_rewards = log_rewards.clamp_min(model.log_reward_clip_min) - log_F_s_next[is_terminating] = log_rewards[is_terminating] - - # Compute scores - preds = log_pf + log_F_s - targets = log_pb + log_F_s_next - scores = preds - targets - assert scores.shape == (transitions.n_transitions,) - return scores - - -def run_once( - mode: str, - n_transitions: int, - forward_looking: bool, - use_compile: bool = False, - device: str | torch.device = "cpu", -) -> float: - """Return median time (seconds) for the chosen mode.""" - model, env, transitions = build_model_and_data( - n_transitions=n_transitions, - forward_looking=forward_looking, - device=device, - ) - - bench: Callable[[], Any] - compiled_get_scores: Callable | None = None - - if mode == "original": - # Use the in-file copy of the current implementation to keep a fixed baseline. - def bench_original(): - return original_get_scores( - model, env, transitions, recalculate_all_logprobs=True - ) - - bench = bench_original - elif mode == "current": - # Benchmarks the method on the model; once optimized, this reflects new code. - if use_compile: - compiled_get_scores = torch.compile( - model.get_scores, - fullgraph=False, - dynamic=False, - mode="reduce-overhead", - ) - - def bench_current(): - fn = ( - compiled_get_scores - if compiled_get_scores is not None - else model.get_scores - ) - return fn(env, transitions) # type: ignore[arg-type] - - bench = bench_current - else: - raise ValueError(mode) - - t = benchmark.Timer( - stmt="bench()", - globals={"bench": bench}, - setup="", - num_threads=torch.get_num_threads(), - ).blocked_autorange(min_run_time=0.5) - return t.median - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--sizes", - nargs="+", - default=["65536", "131072", "262144"], - help="Number of transitions per batch to benchmark (larger to surface runtime differences).", - ) - parser.add_argument( - "--compile", - action="store_true", - help="Use torch.compile for the current/optimized get_scores.", - ) - parser.add_argument( - "--forward-looking", - action="store_true", - help="Enable forward-looking reward path in the benchmark.", - ) - parser.add_argument( - "--device", - default="cpu", - help="Device to run on (e.g., cpu, mps, cuda).", - ) - args = parser.parse_args() - - print("Benchmarking DBGFlowNet.get_scores (Detailed Balance)") - print(f"torch version: {torch.__version__}") - print(f"num threads: {torch.get_num_threads()}") - print(f"forward-looking: {args.forward_looking}") - print() - print(f"{'n_trans':>10} {'orig (ms)':>12} {'curr (ms)':>12} {'speedup':>8}") - - for size in args.sizes: - n_transitions = int(size) - t_orig = ( - run_once( - "original", - n_transitions, - forward_looking=args.forward_looking, - device=args.device, - ) - * 1e3 - ) - t_curr = ( - run_once( - "current", - n_transitions, - forward_looking=args.forward_looking, - use_compile=args.compile, - device=args.device, - ) - * 1e3 - ) - speedup = t_orig / t_curr if t_curr > 0 else float("inf") - print(f"{n_transitions:10d} {t_orig:12.3f} {t_curr:12.3f} {speedup:8.2f}x") - - -if __name__ == "__main__": - main() diff --git a/tutorials/misc/bench_get_scores_all.py b/tutorials/misc/bench_get_scores_all.py new file mode 100644 index 00000000..950e195d --- /dev/null +++ b/tutorials/misc/bench_get_scores_all.py @@ -0,0 +1,1151 @@ +""" +Unified micro-benchmark runner for GFlowNet losses/get_scores. + +Runs, in order: +- Trajectory Balance (TB) loss +- Log Partition Variance (LPV) loss +- Sub-trajectory Balance (SubTB) get_scores +- Detailed Balance (DB) get_scores +- Modified Detailed Balance (ModDB) get_scores + +For each loss, the script reports four timings: + original (baseline, frozen copy) + original+compile (torch.compile applied to the baseline function) + current (eager) + current+compile (torch.compile applied to the current function) +Two speedups are printed: current/original and current+compile/original. + +All benchmarks use embedded base sizes, scaled by a single --size-scale +multiplier. Correctness is checked once per size before timing and is +excluded from the timing loops. +""" + +from __future__ import annotations + +import argparse +import math +from types import MethodType +from typing import Any, Callable, Iterable, Tuple + +import torch +import torch.nn as nn +from torch.utils import benchmark + +from gfn.gflownet.detailed_balance import DBGFlowNet, ModifiedDBGFlowNet +from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet +from gfn.gflownet.trajectory_balance import ( + LogPartitionVarianceGFlowNet, + TBGFlowNet, +) +from gfn.utils.handlers import ( + has_conditions_exception_handler, + no_conditions_exception_handler, +) + +# ------------------------- +# TB / LPV (loss benchmark) +# ------------------------- + + +class _TBTrajectoriesStub: + def __init__( + self, log_rewards: torch.Tensor, conditions: torch.Tensor | None = None + ): + self._log_rewards = log_rewards + self.n_trajectories = log_rewards.shape[0] + self.conditions = conditions + + @property + def log_rewards(self) -> torch.Tensor: + return self._log_rewards + + +def _build_tb( + T: int, N: int, device: torch.device, dtype: torch.dtype +) -> Tuple[TBGFlowNet, _TBTrajectoriesStub, torch.Tensor, torch.Tensor]: + log_pf = torch.randn(T, N, device=device, dtype=dtype) + log_pb = torch.randn(T, N, device=device, dtype=dtype) + log_rewards = torch.randn(N, device=device, dtype=dtype) + trajectories = _TBTrajectoriesStub(log_rewards) + + model = TBGFlowNet.__new__(TBGFlowNet) + nn.Module.__init__(model) + model.debug = False + model.log_reward_clip_min = -float("inf") + model.logZ = nn.Parameter(torch.tensor(0.0, device=device, dtype=dtype)) + + def _get_pfs_and_pbs(self, _trajectories, recalculate_all_logprobs: bool = True): + return log_pf, log_pb + + model.get_pfs_and_pbs = MethodType(_get_pfs_and_pbs, model) + return model, trajectories, log_pf, log_pb + + +def _build_lpv( + T: int, N: int, device: torch.device, dtype: torch.dtype +) -> Tuple[ + LogPartitionVarianceGFlowNet, _TBTrajectoriesStub, torch.Tensor, torch.Tensor +]: + log_pf = torch.randn(T, N, device=device, dtype=dtype) + log_pb = torch.randn(T, N, device=device, dtype=dtype) + log_rewards = torch.randn(N, device=device, dtype=dtype) + trajectories = _TBTrajectoriesStub(log_rewards) + + model = LogPartitionVarianceGFlowNet.__new__(LogPartitionVarianceGFlowNet) + nn.Module.__init__(model) + model.debug = False + model.log_reward_clip_min = -float("inf") + + def _get_pfs_and_pbs(self, _trajectories, recalculate_all_logprobs: bool = True): + return log_pf, log_pb + + model.get_pfs_and_pbs = MethodType(_get_pfs_and_pbs, model) + return model, trajectories, log_pf, log_pb + + +def _tb_original_get_scores( + model, trajectories, log_pf: torch.Tensor, log_pb: torch.Tensor +): + total_log_pf_trajectories = log_pf.sum(dim=0) + total_log_pb_trajectories = log_pb.sum(dim=0) + + log_rewards = trajectories.log_rewards + if math.isfinite(model.log_reward_clip_min): + log_rewards = log_rewards.clamp_min(model.log_reward_clip_min) + + return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards + + +def _tb_original_loss(model, trajectories, log_pf: torch.Tensor, log_pb: torch.Tensor): + scores = _tb_original_get_scores(model, trajectories, log_pf, log_pb) + logZ = torch.as_tensor(model.logZ).squeeze() + scores = (scores + logZ).pow(2) + return scores.mean() + + +def _lpv_original_loss(model, trajectories, log_pf: torch.Tensor, log_pb: torch.Tensor): + scores = _tb_original_get_scores(model, trajectories, log_pf, log_pb) + centered = scores - scores.mean() + return centered.pow(2).mean() + + +# ------------------------- +# SubTB (get_scores) +# ------------------------- + + +class _SubTBDummyTrajectories: + def __init__(self, terminating_idx: torch.Tensor, max_length: int): + self.terminating_idx = terminating_idx + self.max_length = max_length + self.n_trajectories = terminating_idx.shape[0] + + def __len__(self) -> int: + return self.n_trajectories + + +def _subtb_build_model_and_data( + max_len: int, n_traj: int, seed: int = 0, device: str | torch.device = "cpu" +) -> Tuple[ + SubTBGFlowNet, _SubTBDummyTrajectories, list[torch.Tensor], list[torch.Tensor] +]: + torch.manual_seed(seed) + device = torch.device(device) + terminating_idx = torch.randint(1, max_len + 1, (n_traj,), device=device) + log_rewards = torch.randn(n_traj, device=device) + log_pf_trajectories = torch.randn(max_len, n_traj, device=device) + log_pb_trajectories = torch.randn(max_len, n_traj, device=device) + log_state_flows = torch.randn(max_len, n_traj, device=device) + sink_states_mask = torch.zeros(max_len, n_traj, dtype=torch.bool, device=device) + is_terminal_mask = torch.zeros(max_len, n_traj, dtype=torch.bool, device=device) + + preds_list = [ + torch.randn(max_len + 1 - i, n_traj, device=device) + for i in range(1, max_len + 1) + ] + targets_list = [ + torch.randn(max_len + 1 - i, n_traj, device=device) + for i in range(1, max_len + 1) + ] + + trajectories = _SubTBDummyTrajectories( + terminating_idx=terminating_idx, max_length=max_len + ) + + model = SubTBGFlowNet.__new__(SubTBGFlowNet) + torch.nn.Module.__init__(model) + model.debug = False + model.log_reward_clip_min = float("-inf") + + model.get_pfs_and_pbs = MethodType( + lambda self, traj, recalculate_all_logprobs=True: ( + log_pf_trajectories, + log_pb_trajectories, + ), + model, + ) + model.calculate_log_state_flows = MethodType( + lambda self, _env, _traj, _log_pf: log_state_flows, model + ) + model.calculate_masks = MethodType( + lambda self, _log_state_flows, _traj: (sink_states_mask, is_terminal_mask), + model, + ) + trajectories.log_rewards = log_rewards + model.calculate_preds = MethodType( + lambda self, _log_pf_cum, _log_state_flows, i: preds_list[i - 1], model + ) + model.calculate_targets = MethodType( + lambda self, _traj, _preds, _log_pb_cum, _log_state_flows, _term_mask, _sink_mask, i: targets_list[ + i - 1 + ], + model, + ) + + return model, trajectories, preds_list, targets_list + + +def _subtb_original_get_scores( + model: SubTBGFlowNet, env, trajectories +) -> Tuple[list[torch.Tensor], list[torch.Tensor]]: + log_pf_trajectories_, log_pb_trajectories_ = model.get_pfs_and_pbs( + trajectories, recalculate_all_logprobs=True + ) + + log_pf_trajectories_cum = model.cumulative_logprobs( + trajectories, log_pf_trajectories_ + ) + log_pb_trajectories_cum = model.cumulative_logprobs( + trajectories, log_pb_trajectories_ + ) + + log_state_flows_ = model.calculate_log_state_flows( + env, trajectories, log_pf_trajectories_ + ) + sink_states_mask_, is_terminal_mask_ = model.calculate_masks( + log_state_flows_, trajectories + ) + + flattening_masks_orig = [] + scores_orig = [] + for i in range(1, 1 + trajectories.max_length): + preds = model.calculate_preds(log_pf_trajectories_cum, log_state_flows_, i) + targets = model.calculate_targets( + trajectories, + preds, + log_pb_trajectories_cum, + log_state_flows_, + is_terminal_mask_, + sink_states_mask_, + i, + ) + + flattening_mask = trajectories.terminating_idx.lt( + torch.arange( + i, + trajectories.max_length + 1, + device=trajectories.terminating_idx.device, + ).unsqueeze(-1) + ) + + flat_preds = preds[~flattening_mask] + if model.debug and torch.any(torch.isnan(flat_preds)): + raise ValueError("NaN in preds") + + flat_targets = targets[~flattening_mask] + if model.debug and torch.any(torch.isnan(flat_targets)): + raise ValueError("NaN in targets") + + flattening_masks_orig.append(flattening_mask) + scores_orig.append(preds - targets) + + return scores_orig, flattening_masks_orig + + +# ------------------------- +# DB / ModDB (get_scores) +# ------------------------- + + +class _DBDummyStates: + def __init__(self, tensor: torch.Tensor, is_sink_state: torch.Tensor | None = None): + self.tensor = tensor + self.is_sink_state = ( + is_sink_state + if is_sink_state is not None + else torch.zeros(tensor.shape[0], dtype=torch.bool, device=tensor.device) + ) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_DBDummyStates": + return _DBDummyStates(self.tensor[idx], self.is_sink_state[idx]) + + @property + def batch_shape(self) -> torch.Size: + return self.tensor.shape[:-1] + + @property + def device(self) -> torch.device: + return self.tensor.device + + +class _DBDummyActions: + exit_action = torch.tensor(0, dtype=torch.long) + + def __init__(self, tensor: torch.Tensor): + self.tensor = tensor + self.is_exit = torch.zeros_like(tensor, dtype=torch.bool) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_DBDummyActions": + return _DBDummyActions(self.tensor[idx]) + + @property + def batch_shape(self) -> torch.Size: + return self.tensor.shape + + +class _DBDummyTransitions: + def __init__( + self, + states: _DBDummyStates, + next_states: _DBDummyStates, + actions: _DBDummyActions, + is_terminating: torch.Tensor, + log_rewards: torch.Tensor, + conditions: torch.Tensor | None = None, + ): + self.states = states + self.next_states = next_states + self.actions = actions + self.is_terminating = is_terminating + self.is_backward = False + self.conditions = conditions + self.log_rewards = log_rewards + self.device = states.device + self.n_transitions = len(states) + + def __len__(self) -> int: + return self.n_transitions + + +class _DBDummyEnv: + def __init__(self, log_reward_fn: Callable[[Any, Any | None], torch.Tensor]): + self._log_reward_fn = log_reward_fn + + def log_reward(self, states: Any, conditions: Any | None = None) -> torch.Tensor: + return self._log_reward_fn(states, conditions) + + +def _db_build_model_and_data( + n_transitions: int, + seed: int = 0, + device: str | torch.device = "cpu", + forward_looking: bool = False, +) -> Tuple[DBGFlowNet, _DBDummyEnv, _DBDummyTransitions]: + torch.manual_seed(seed) + device = torch.device(device) + + states_tensor = torch.randn(n_transitions, 4, device=device) + next_states_tensor = torch.randn(n_transitions, 4, device=device) + is_sink_state = torch.zeros(n_transitions, dtype=torch.bool, device=device) + states = _DBDummyStates(states_tensor, is_sink_state=is_sink_state) + next_states = _DBDummyStates(next_states_tensor, is_sink_state=is_sink_state.clone()) + + is_terminating = torch.zeros(n_transitions, dtype=torch.bool, device=device) + is_terminating[::3] = True + actions = _DBDummyActions( + torch.zeros(n_transitions, dtype=torch.long, device=device) + ) + log_rewards = torch.randn(n_transitions, device=device) + + transitions = _DBDummyTransitions( + states=states, + next_states=next_states, + actions=actions, + is_terminating=is_terminating, + log_rewards=log_rewards, + conditions=None, + ) + + log_pf = torch.randn(n_transitions, device=device) + log_pb = torch.randn(n_transitions, device=device) + logF_states = torch.randn(n_transitions, 1, device=device) + logF_next = torch.randn(n_transitions, 1, device=device) + log_reward_states = torch.randn(n_transitions, device=device) + log_reward_next = torch.randn(n_transitions, device=device) + + def get_pfs_and_pbs_stub(_self, _transitions, recalculate_all_logprobs: bool = True): + return log_pf, log_pb + + def logF_stub(_self, s, _conditions=None): + length = len(s) + if length == n_transitions: + return logF_states + return logF_next[:length] + + def log_reward_stub(_states, _conditions=None): + length = len(_states) + if length == n_transitions: + return log_reward_states + return log_reward_next[:length] + + env = _DBDummyEnv(log_reward_stub) + + model = DBGFlowNet.__new__(DBGFlowNet) + torch.nn.Module.__init__(model) + model.debug = False + model.forward_looking = forward_looking + model.log_reward_clip_min = -float("inf") + model.get_pfs_and_pbs = MethodType(get_pfs_and_pbs_stub, model) + model.logF = MethodType(logF_stub, model) + + return model, env, transitions + + +def _db_original_get_scores( + model: DBGFlowNet, + env: _DBDummyEnv, + transitions: _DBDummyTransitions, + recalculate_all_logprobs: bool = True, +) -> torch.Tensor: + if model.debug and transitions.is_backward: + raise ValueError("Backward transitions are not supported") + + states = transitions.states + transitions.actions + + if len(states) == 0: + return torch.tensor(0.0, device=transitions.device) + + log_pf, log_pb = model.get_pfs_and_pbs( + transitions, recalculate_all_logprobs=recalculate_all_logprobs + ) + + if transitions.conditions is not None: + with has_conditions_exception_handler("logF", model.logF): + log_F_s = model.logF(states, transitions.conditions).squeeze(-1) + else: + with no_conditions_exception_handler("logF", model.logF): + log_F_s = model.logF(states).squeeze(-1) + + log_F_s_next = torch.zeros_like(log_F_s) + is_terminating = transitions.is_terminating + is_intermediate = ~is_terminating + + interm_next_states = transitions.next_states[is_intermediate] + if transitions.conditions is not None: + with has_conditions_exception_handler("logF", model.logF): + log_F_s_next[is_intermediate] = model.logF( + interm_next_states, + transitions.conditions[is_intermediate], + ).squeeze(-1) + else: + with no_conditions_exception_handler("logF", model.logF): + log_F_s_next[is_intermediate] = model.logF(interm_next_states).squeeze(-1) + + if model.forward_looking: + if transitions.conditions is not None: + log_rewards_state = env.log_reward(states, transitions.conditions) + log_rewards_next = env.log_reward( + interm_next_states, transitions.conditions[is_intermediate] + ) + else: + log_rewards_state = env.log_reward(states) + log_rewards_next = env.log_reward(interm_next_states) + + log_rewards_state = log_rewards_state.clamp_min(model.log_reward_clip_min) + log_rewards_next = log_rewards_next.clamp_min(model.log_reward_clip_min) + + log_F_s = log_F_s + log_rewards_state + log_F_s_next[is_intermediate] = log_F_s_next[is_intermediate] + log_rewards_next + + log_rewards = transitions.log_rewards + log_rewards = log_rewards.clamp_min(model.log_reward_clip_min) + log_F_s_next[is_terminating] = log_rewards[is_terminating] + + preds = log_pf + log_F_s + targets = log_pb + log_F_s_next + scores = preds - targets + return scores + + +# ----- Modified DB ----- + + +class _ModDBDummyStates: + def __init__(self, tensor: torch.Tensor, is_sink_state: torch.Tensor | None = None): + self.tensor = tensor + self.is_sink_state = ( + is_sink_state + if is_sink_state is not None + else torch.zeros(tensor.shape[0], dtype=torch.bool, device=tensor.device) + ) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_ModDBDummyStates": + return _ModDBDummyStates(self.tensor[idx], self.is_sink_state[idx]) + + @property + def device(self) -> torch.device: + return self.tensor.device + + +class _ModDBDummyActions: + exit_action = torch.tensor(0, dtype=torch.long) + + def __init__(self, tensor: torch.Tensor, is_exit: torch.Tensor | None = None): + self.tensor = tensor + self.is_exit = ( + is_exit + if is_exit is not None + else torch.zeros_like(tensor, dtype=torch.bool) + ) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_ModDBDummyActions": + return _ModDBDummyActions(self.tensor[idx], self.is_exit[idx]) + + +class _ModDBDummyTransitions: + def __init__( + self, + states: _ModDBDummyStates, + next_states: _ModDBDummyStates, + actions: _ModDBDummyActions, + all_log_rewards: torch.Tensor, + is_backward: bool = False, + log_probs: torch.Tensor | None = None, + has_log_probs: bool = False, + conditions: torch.Tensor | None = None, + ): + self.states = states + self.next_states = next_states + self.actions = actions + self.all_log_rewards = all_log_rewards + self.is_backward = is_backward + self.log_probs = log_probs + self.has_log_probs = has_log_probs + self.conditions = conditions + self.device = states.device + self.n_transitions = len(states) + + def __len__(self) -> int: + return self.n_transitions + + def __getitem__(self, idx) -> "_ModDBDummyTransitions": + return _ModDBDummyTransitions( + self.states[idx], + self.next_states[idx], + self.actions[idx], + self.all_log_rewards[idx], + self.is_backward, + self.log_probs[idx] if self.log_probs is not None else None, + self.has_log_probs, + self.conditions[idx] if self.conditions is not None else None, + ) + + +class _ModDBFakeDist: + def __init__(self, log_action: torch.Tensor, log_exit: torch.Tensor): + self._log_action = log_action + self._log_exit = log_exit + + def log_prob(self, action_tensor: torch.Tensor) -> torch.Tensor: + n = action_tensor.shape[0] + if action_tensor.shape == self._log_exit.shape: + return self._log_exit + return self._log_action[:n] + + +class _ModDBDummyEstimator: + def __init__( + self, + log_action: torch.Tensor, + log_exit: torch.Tensor, + ): + self._log_action = log_action + self._log_exit = log_exit + + def __call__(self, states: _ModDBDummyStates, conditions=None): + return None + + def to_probability_distribution(self, states: _ModDBDummyStates, module_output=None): + return _ModDBFakeDist(self._log_action, self._log_exit) + + +def _moddb_build_model_and_data( + n_transitions: int, + seed: int = 0, + device: str | torch.device = "cpu", +) -> Tuple[ModifiedDBGFlowNet, _ModDBDummyTransitions]: + torch.manual_seed(seed) + device = torch.device(device) + + states_tensor = torch.randn(n_transitions, 4, device=device) + next_states_tensor = torch.randn(n_transitions, 4, device=device) + + is_sink_state = torch.zeros(n_transitions, dtype=torch.bool, device=device) + is_sink_state[::4] = True + states = _ModDBDummyStates( + states_tensor, is_sink_state=torch.zeros_like(is_sink_state) + ) + next_states = _ModDBDummyStates(next_states_tensor, is_sink_state=is_sink_state) + + actions_tensor = torch.randint(0, 5, (n_transitions,), device=device) + is_exit = torch.zeros_like(actions_tensor, dtype=torch.bool) + actions = _ModDBDummyActions(actions_tensor, is_exit=is_exit) + + all_log_rewards = torch.randn(n_transitions, 2, device=device) + + transitions = _ModDBDummyTransitions( + states=states, + next_states=next_states, + actions=actions, + all_log_rewards=all_log_rewards, + has_log_probs=False, + log_probs=None, + conditions=None, + ) + + non_sink_count = int((~is_sink_state).sum().item()) + log_pf_action = torch.randn(non_sink_count, device=device) + log_pf_exit = torch.randn(non_sink_count, device=device) + log_pf_exit_next = torch.randn(non_sink_count, device=device) + log_pb_action = torch.randn(non_sink_count, device=device) + + pf_estimator = _ModDBDummyEstimator(log_pf_action, log_pf_exit) + pb_estimator = _ModDBDummyEstimator(log_pb_action, log_pf_exit_next) + + model = ModifiedDBGFlowNet.__new__(ModifiedDBGFlowNet) + torch.nn.Module.__init__(model) + model.debug = False + model.constant_pb = False + model.pf = pf_estimator + model.pb = pb_estimator + model.log_reward_clip_min = -float("inf") + + return model, transitions + + +def _moddb_original_get_scores( + model: ModifiedDBGFlowNet, + transitions: _ModDBDummyTransitions, + recalculate_all_logprobs: bool = True, +) -> torch.Tensor: + if model.debug and transitions.is_backward: + raise ValueError("Backward transitions are not supported") + + if len(transitions) == 0: + return torch.tensor(0.0, device=transitions.device) + + mask = ~transitions.next_states.is_sink_state + states = transitions.states[mask] + valid_next_states = transitions.next_states[mask] + actions = transitions.actions[mask] + all_log_rewards = transitions.all_log_rewards[mask] + + if transitions.conditions is not None: + with has_conditions_exception_handler("pf", model.pf): + module_output = model.pf(states, transitions.conditions[mask]) + else: + with no_conditions_exception_handler("pf", model.pf): + module_output = model.pf(states) + + if len(states) == 0: + return torch.tensor(0.0, device=transitions.device) + + pf_dist = model.pf.to_probability_distribution(states, module_output) + + if transitions.has_log_probs and not recalculate_all_logprobs: + valid_log_pf_actions = transitions[mask].log_probs + assert valid_log_pf_actions is not None + else: + valid_log_pf_actions = pf_dist.log_prob(actions.tensor) + exit_action_tensor = actions.__class__.exit_action.to( + actions.tensor.device, dtype=actions.tensor.dtype + ).expand_as(actions.tensor) + valid_log_pf_s_exit = pf_dist.log_prob(exit_action_tensor) + + if transitions.conditions is not None: + with has_conditions_exception_handler("pf", model.pf): + module_output = model.pf(valid_next_states, transitions.conditions[mask]) + else: + with no_conditions_exception_handler("pf", model.pf): + module_output = model.pf(valid_next_states) + + valid_log_pf_s_prime_exit = model.pf.to_probability_distribution( + valid_next_states, module_output + ).log_prob(exit_action_tensor[: len(valid_next_states)]) + + non_exit_actions = actions[~actions.is_exit] + + if model.pb is not None: + if transitions.conditions is not None: + with has_conditions_exception_handler("pb", model.pb): + module_output = model.pb(valid_next_states, transitions.conditions[mask]) + else: + with no_conditions_exception_handler("pb", model.pb): + module_output = model.pb(valid_next_states) + + valid_log_pb_actions = model.pb.to_probability_distribution( + valid_next_states, module_output + ).log_prob(non_exit_actions.tensor) + else: + valid_log_pb_actions = torch.zeros_like(valid_log_pf_s_exit) + + preds = all_log_rewards[:, 0] + valid_log_pf_actions + valid_log_pf_s_prime_exit + targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit + + scores = preds - targets + return scores + + +# ------------------------- +# Helpers +# ------------------------- + + +def _select_dtype(name: str, device: torch.device) -> torch.dtype: + if name == "fp32": + return torch.float32 + if name == "fp16": + if device.type != "cuda": + raise ValueError("fp16 is CUDA-only for this benchmark") + return torch.float16 + if name == "bf16": + return torch.bfloat16 + raise ValueError(f"Unsupported dtype {name}") + + +def _scale_int(value: int, scale: float) -> int: + return max(1, int(round(value * scale))) + + +def _scale_pair(pair: Tuple[int, int], scale: float) -> Tuple[int, int]: + return _scale_int(pair[0], scale), _scale_int(pair[1], scale) + + +def _time_fn(fn: Callable[[], Any]) -> float: + t = benchmark.Timer( + stmt="fn()", + globals={"fn": fn}, + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=0.5) + return t.median + + +def _maybe_compile( + fn: Callable[..., Any] | None, enabled: bool +) -> Callable[..., Any] | None: + if fn is None or not enabled: + return None + return torch.compile(fn, fullgraph=False, dynamic=False, mode="reduce-overhead") + + +def _format_ms(value: float | None) -> str: + if value is None: + return " - " + return f"{value*1e3:10.3f}" + + +def _run_with_compile_variants( + eager_fn: Callable[[], Any], + compile_enabled: bool, +) -> tuple[float, float | None]: + t_eager = _time_fn(eager_fn) + compiled_fn = _maybe_compile(eager_fn, compile_enabled) + t_compiled = _time_fn(compiled_fn) if compiled_fn is not None else None + return t_eager, t_compiled + + +def _run_tb_or_lpv( + variant: str, + sizes: Iterable[int], + T: int, + device: torch.device, + dtype: torch.dtype, + repeat: int, + compile_enabled: bool, +): + print(f"\n=== {variant.upper()} loss ===") + print( + f"{'N':>10} {'chk':>4} {'orig(ms)':>10} {'orig+c(ms)':>10} {'curr(ms)':>10} {'curr+c(ms)':>10} {'spd':>6} {'spd_c':>6}" + ) + for size in sizes: + N = int(size) + # correctness (once, not timed) + if variant == "tb": + model, trajectories, log_pf, log_pb = _build_tb(T, N, device, dtype) + orig_val = _tb_original_loss(model, trajectories, log_pf, log_pb) + curr_val = model.loss(None, trajectories, recalculate_all_logprobs=False) # type: ignore[arg-type] + else: + model, trajectories, log_pf, log_pb = _build_lpv(T, N, device, dtype) + orig_val = _lpv_original_loss(model, trajectories, log_pf, log_pb) + curr_val = model.loss(None, trajectories, recalculate_all_logprobs=False) # type: ignore[arg-type] + diff = (orig_val - curr_val).abs().max().item() + tol = 1e-6 if dtype == torch.float32 else 5e-4 + status = "PASS" if diff <= tol else "FAIL" + + orig_times = [] + origc_times = [] + curr_times = [] + currc_times = [] + for _ in range(repeat): + t_orig, t_origc = _run_with_compile_variants( + eager_fn=lambda: ( + _tb_original_loss(model, trajectories, log_pf, log_pb) + if variant == "tb" + else _lpv_original_loss(model, trajectories, log_pf, log_pb) + ), + compile_enabled=compile_enabled, + ) + t_curr, t_currc = _run_with_compile_variants( + eager_fn=lambda: model.loss( + None, trajectories, recalculate_all_logprobs=False # type: ignore[arg-type] + ), + compile_enabled=compile_enabled, + ) + orig_times.append(t_orig) + curr_times.append(t_curr) + if t_origc is not None: + origc_times.append(t_origc) + if t_currc is not None: + currc_times.append(t_currc) + + t_orig_ms = torch.tensor(orig_times).median().item() * 1e3 + t_origc_ms = ( + torch.tensor(origc_times).median().item() * 1e3 if origc_times else None + ) + t_curr_ms = torch.tensor(curr_times).median().item() * 1e3 + t_currc_ms = ( + torch.tensor(currc_times).median().item() * 1e3 if currc_times else None + ) + speedup = t_orig_ms / t_curr_ms if t_curr_ms > 0 else float("inf") + speedup_c = ( + t_orig_ms / t_currc_ms if t_currc_ms and t_currc_ms > 0 else float("inf") + ) + print( + f"{N:10d} {status:>4} {_format_ms(t_orig_ms)} {_format_ms(t_origc_ms)} {_format_ms(t_curr_ms)} {_format_ms(t_currc_ms)} {speedup:6.2f} {speedup_c:6.2f}" + ) + + +def _run_subtb( + sizes: Iterable[Tuple[int, int]], + device: torch.device, + repeat: int, + compile_enabled: bool, +): + print("\n=== SubTB get_scores ===") + print( + f"{'size':>12} {'chk':>4} {'orig(ms)':>10} {'orig+c(ms)':>10} {'curr(ms)':>10} {'curr+c(ms)':>10} {'spd':>6} {'spd_c':>6}" + ) + for max_len, n_traj in sizes: + model, trajectories, _, _ = _subtb_build_model_and_data( + max_len, n_traj, device=device + ) + env_obj: Any = object() + + # correctness (once, not timed) + orig_scores, _ = _subtb_original_get_scores(model, env_obj, trajectories) + curr_scores, _ = model.get_scores(env_obj, trajectories) # type: ignore[arg-type] + max_abs = max( + (orig - curr).abs().max().item() + for orig, curr in zip(orig_scores, curr_scores) + ) + tol = 1e-6 + status = "PASS" if max_abs <= tol else "FAIL" + + orig_times = [] + origc_times = [] + curr_times = [] + currc_times = [] + for _ in range(repeat): + t_orig, t_origc = _run_with_compile_variants( + eager_fn=lambda: _subtb_original_get_scores( + model, env_obj, trajectories + ), + compile_enabled=compile_enabled, + ) + t_curr, t_currc = _run_with_compile_variants( + eager_fn=lambda: model.get_scores(env_obj, trajectories), # type: ignore[arg-type] + compile_enabled=compile_enabled, + ) + orig_times.append(t_orig) + curr_times.append(t_curr) + if t_origc is not None: + origc_times.append(t_origc) + if t_currc is not None: + currc_times.append(t_currc) + + t_orig_ms = torch.tensor(orig_times).median().item() * 1e3 + t_origc_ms = ( + torch.tensor(origc_times).median().item() * 1e3 if origc_times else None + ) + t_curr_ms = torch.tensor(curr_times).median().item() * 1e3 + t_currc_ms = ( + torch.tensor(currc_times).median().item() * 1e3 if currc_times else None + ) + speedup = t_orig_ms / t_curr_ms if t_curr_ms > 0 else float("inf") + speedup_c = ( + t_orig_ms / t_currc_ms if t_currc_ms and t_currc_ms > 0 else float("inf") + ) + print( + f"{max_len}x{n_traj:5d} {status:>4} {_format_ms(t_orig_ms)} {_format_ms(t_origc_ms)} {_format_ms(t_curr_ms)} {_format_ms(t_currc_ms)} {speedup:6.2f} {speedup_c:6.2f}" + ) + + +def _run_db( + sizes: Iterable[int], + device: torch.device, + repeat: int, + compile_enabled: bool, + forward_looking: bool, +): + print("\n=== DB get_scores ===") + print( + f"{'n_trans':>10} {'chk':>4} {'orig(ms)':>10} {'orig+c(ms)':>10} {'curr(ms)':>10} {'curr+c(ms)':>10} {'spd':>6} {'spd_c':>6}" + ) + for n_transitions in sizes: + model, env, transitions = _db_build_model_and_data( + n_transitions=n_transitions, + device=device, + forward_looking=forward_looking, + ) + + # correctness (once, not timed) + orig = _db_original_get_scores( + model, env, transitions, recalculate_all_logprobs=True + ) + curr = model.get_scores(env, transitions) # type: ignore[arg-type] + max_abs = (orig - curr).abs().max().item() + tol = 1e-6 + status = "PASS" if max_abs <= tol else "FAIL" + + orig_times = [] + origc_times = [] + curr_times = [] + currc_times = [] + for _ in range(repeat): + t_orig, t_origc = _run_with_compile_variants( + eager_fn=lambda: _db_original_get_scores( + model, env, transitions, recalculate_all_logprobs=True + ), + compile_enabled=compile_enabled, + ) + t_curr, t_currc = _run_with_compile_variants( + eager_fn=lambda: model.get_scores(env, transitions), # type: ignore[arg-type] + compile_enabled=compile_enabled, + ) + orig_times.append(t_orig) + curr_times.append(t_curr) + if t_origc is not None: + origc_times.append(t_origc) + if t_currc is not None: + currc_times.append(t_currc) + + t_orig_ms = torch.tensor(orig_times).median().item() * 1e3 + t_origc_ms = ( + torch.tensor(origc_times).median().item() * 1e3 if origc_times else None + ) + t_curr_ms = torch.tensor(curr_times).median().item() * 1e3 + t_currc_ms = ( + torch.tensor(currc_times).median().item() * 1e3 if currc_times else None + ) + speedup = t_orig_ms / t_curr_ms if t_curr_ms > 0 else float("inf") + speedup_c = ( + t_orig_ms / t_currc_ms if t_currc_ms and t_currc_ms > 0 else float("inf") + ) + print( + f"{n_transitions:10d} {status:>4} {_format_ms(t_orig_ms)} {_format_ms(t_origc_ms)} {_format_ms(t_curr_ms)} {_format_ms(t_currc_ms)} {speedup:6.2f} {speedup_c:6.2f}" + ) + + +def _run_moddb( + sizes: Iterable[int], + device: torch.device, + repeat: int, + compile_enabled: bool, +): + print("\n=== Modified DB get_scores ===") + print( + f"{'n_trans':>10} {'chk':>4} {'orig(ms)':>10} {'orig+c(ms)':>10} {'curr(ms)':>10} {'curr+c(ms)':>10} {'spd':>6} {'spd_c':>6}" + ) + for n_transitions in sizes: + model, transitions = _moddb_build_model_and_data( + n_transitions=n_transitions, + device=device, + ) + + # correctness (once, not timed) + orig = _moddb_original_get_scores( + model, transitions, recalculate_all_logprobs=True + ) + curr = model.get_scores(transitions) # type: ignore[arg-type] + max_abs = (orig - curr).abs().max().item() + tol = 1e-6 + status = "PASS" if max_abs <= tol else "FAIL" + + orig_times = [] + origc_times = [] + curr_times = [] + currc_times = [] + for _ in range(repeat): + t_orig, t_origc = _run_with_compile_variants( + eager_fn=lambda: _moddb_original_get_scores( + model, transitions, recalculate_all_logprobs=True + ), + compile_enabled=compile_enabled, + ) + t_curr, t_currc = _run_with_compile_variants( + eager_fn=lambda: model.get_scores(transitions), # type: ignore[arg-type] + compile_enabled=compile_enabled, + ) + orig_times.append(t_orig) + curr_times.append(t_curr) + if t_origc is not None: + origc_times.append(t_origc) + if t_currc is not None: + currc_times.append(t_currc) + + t_orig_ms = torch.tensor(orig_times).median().item() * 1e3 + t_origc_ms = ( + torch.tensor(origc_times).median().item() * 1e3 if origc_times else None + ) + t_curr_ms = torch.tensor(curr_times).median().item() * 1e3 + t_currc_ms = ( + torch.tensor(currc_times).median().item() * 1e3 if currc_times else None + ) + speedup = t_orig_ms / t_curr_ms if t_curr_ms > 0 else float("inf") + speedup_c = ( + t_orig_ms / t_currc_ms if t_currc_ms and t_currc_ms > 0 else float("inf") + ) + print( + f"{n_transitions:10d} {status:>4} {_format_ms(t_orig_ms)} {_format_ms(t_origc_ms)} {_format_ms(t_curr_ms)} {_format_ms(t_currc_ms)} {speedup:6.2f} {speedup_c:6.2f}" + ) + + +# ------------------------- +# Main +# ------------------------- + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--device", default="cpu", help="Device to run on (cpu, mps, cuda)." + ) + parser.add_argument( + "--dtype", + choices=["fp32", "fp16", "bf16"], + default="fp32", + help="Floating dtype for TB/LPV (others use fp32).", + ) + parser.add_argument( + "--size-scale", + type=float, + default=1.0, + help="Scale factor applied to all embedded base sizes.", + ) + parser.add_argument( + "--repeat", + type=int, + default=1, + help="Repeat each size multiple times; medians are reported.", + ) + parser.add_argument( + "--compile", + default=True, + action=argparse.BooleanOptionalAction, + help="Enable torch.compile for both original and current functions.", + ) + parser.add_argument( + "--forward-looking", + action="store_true", + help="Enable forward-looking reward path for DB benchmark.", + ) + return parser.parse_args() + + +def main(): + args = _parse_args() + device = torch.device(args.device) + if device.type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available") + dtype = _select_dtype(args.dtype, device) + + # Base sizes + base_tb_sizes = [10240, 40960, 163840] + base_T = 64 + base_subtb_sizes = [(80, 640), (160, 1280), (320, 2560)] + base_db_sizes = [65536, 131072, 262144] + base_moddb_sizes = [65536, 131072, 262144] + + tb_sizes = [_scale_int(n, args.size_scale) for n in base_tb_sizes] + subtb_sizes = [_scale_pair(p, args.size_scale) for p in base_subtb_sizes] + db_sizes = [_scale_int(n, args.size_scale) for n in base_db_sizes] + moddb_sizes = [_scale_int(n, args.size_scale) for n in base_moddb_sizes] + + print("Benchmarking TB/LPV/SubTB/DB/ModDB in sequence") + print(f"torch version: {torch.__version__}") + print(f"device: {device}") + print(f"dtype (TB/LPV): {dtype}") + print(f"num threads: {torch.get_num_threads()}") + print(f"size-scale: {args.size_scale}") + print(f"compile: {args.compile}") + print(f"repeat: {args.repeat}") + print(f"forward-looking (DB): {args.forward_looking}") + print() + print( + "Columns: original, original+compile, current, current+compile, speedup vs original (eager and compiled)." + ) + + _run_tb_or_lpv( + variant="tb", + sizes=tb_sizes, + T=base_T, + device=device, + dtype=dtype, + repeat=args.repeat, + compile_enabled=args.compile, + ) + _run_tb_or_lpv( + variant="lpv", + sizes=tb_sizes, + T=base_T, + device=device, + dtype=dtype, + repeat=args.repeat, + compile_enabled=args.compile, + ) + _run_subtb( + sizes=subtb_sizes, + device=device, + repeat=args.repeat, + compile_enabled=args.compile, + ) + _run_db( + sizes=db_sizes, + device=device, + repeat=args.repeat, + compile_enabled=args.compile, + forward_looking=args.forward_looking, + ) + _run_moddb( + sizes=moddb_sizes, + device=device, + repeat=args.repeat, + compile_enabled=args.compile, + ) + + +if __name__ == "__main__": + main() diff --git a/tutorials/misc/bench_moddb_get_scores.py b/tutorials/misc/bench_moddb_get_scores.py deleted file mode 100644 index 0ea8f0c2..00000000 --- a/tutorials/misc/bench_moddb_get_scores.py +++ /dev/null @@ -1,382 +0,0 @@ -""" -Micro-benchmark for ModifiedDBGFlowNet.get_scores baseline vs optimized. -Modeled after bench_db_get_scores.py but targets the modified DB path. -""" - -from __future__ import annotations - -import argparse -from typing import Any, Callable, Tuple - -import torch -from torch.utils import benchmark - -from gfn.gflownet.detailed_balance import ModifiedDBGFlowNet -from gfn.utils.handlers import ( - has_conditions_exception_handler, - no_conditions_exception_handler, -) - - -class _DummyStates: - """Minimal stand-in for States; keeps only what get_scores touches.""" - - def __init__(self, tensor: torch.Tensor, is_sink_state: torch.Tensor | None = None): - self.tensor = tensor - self.is_sink_state = ( - is_sink_state - if is_sink_state is not None - else torch.zeros(tensor.shape[0], dtype=torch.bool, device=tensor.device) - ) - - def __len__(self) -> int: - return self.tensor.shape[0] - - def __getitem__(self, idx) -> "_DummyStates": - return _DummyStates(self.tensor[idx], self.is_sink_state[idx]) - - @property - def device(self) -> torch.device: - return self.tensor.device - - -class _DummyActions: - """Minimal stand-in for Actions; only tensor and is_exit are needed here.""" - - exit_action = torch.tensor(0, dtype=torch.long) - - def __init__(self, tensor: torch.Tensor, is_exit: torch.Tensor | None = None): - self.tensor = tensor - self.is_exit = ( - is_exit - if is_exit is not None - else torch.zeros_like(tensor, dtype=torch.bool) - ) - - def __len__(self) -> int: - return self.tensor.shape[0] - - def __getitem__(self, idx) -> "_DummyActions": - return _DummyActions(self.tensor[idx], self.is_exit[idx]) - - -class _DummyTransitions: - """Carries the attributes touched by ModifiedDBGFlowNet.get_scores.""" - - def __init__( - self, - states: _DummyStates, - next_states: _DummyStates, - actions: _DummyActions, - all_log_rewards: torch.Tensor, - is_backward: bool = False, - log_probs: torch.Tensor | None = None, - has_log_probs: bool = False, - conditions: torch.Tensor | None = None, - ): - self.states = states - self.next_states = next_states - self.actions = actions - self.all_log_rewards = all_log_rewards - self.is_backward = is_backward - self.log_probs = log_probs - self.has_log_probs = has_log_probs - self.conditions = conditions - self.device = states.device - self.n_transitions = len(states) - - def __len__(self) -> int: - return self.n_transitions - - def __getitem__(self, idx) -> "_DummyTransitions": - return _DummyTransitions( - self.states[idx], - self.next_states[idx], - self.actions[idx], - self.all_log_rewards[idx], - self.is_backward, - self.log_probs[idx] if self.log_probs is not None else None, - self.has_log_probs, - self.conditions[idx] if self.conditions is not None else None, - ) - - -class _FakeDist: - """Simple distribution wrapper returning preset log-probs.""" - - def __init__(self, log_action: torch.Tensor, log_exit: torch.Tensor): - self._log_action = log_action - self._log_exit = log_exit - - def log_prob(self, action_tensor: torch.Tensor) -> torch.Tensor: - # Match shape to input; ignore actual action values to focus on timing. - n = action_tensor.shape[0] - # Broadcasting to match shape; slicing guards shorter inputs (next_states path). - if action_tensor.shape == self._log_exit.shape: - return self._log_exit - return self._log_action[:n] - - -class _DummyEstimator: - """Estimator stub providing to_probability_distribution and call signature.""" - - def __init__( - self, - log_action: torch.Tensor, - log_exit: torch.Tensor, - ): - self._log_action = log_action - self._log_exit = log_exit - - def __call__(self, states: _DummyStates, conditions=None): - # Return a placeholder; not used by FakeDist. - return None - - def to_probability_distribution(self, states: _DummyStates, module_output=None): - # Provide a fresh FakeDist per call to mirror API shape; uses preset tensors. - return _FakeDist(self._log_action, self._log_exit) - - -def build_model_and_data( - n_transitions: int, - seed: int = 0, - device: str | torch.device = "cpu", -) -> Tuple[ModifiedDBGFlowNet, _DummyTransitions]: - """Set up a minimal ModifiedDBGFlowNet + transitions for benchmarking.""" - torch.manual_seed(seed) - device = torch.device(device) - - states_tensor = torch.randn(n_transitions, 4, device=device) - next_states_tensor = torch.randn(n_transitions, 4, device=device) - - # Mix of sink/non-sink next states to exercise masking. - is_sink_state = torch.zeros(n_transitions, dtype=torch.bool, device=device) - is_sink_state[::4] = True - states = _DummyStates(states_tensor, is_sink_state=torch.zeros_like(is_sink_state)) - next_states = _DummyStates(next_states_tensor, is_sink_state=is_sink_state) - - # Actions and exits. - actions_tensor = torch.randint(0, 5, (n_transitions,), device=device) - is_exit = torch.zeros_like(actions_tensor, dtype=torch.bool) - actions = _DummyActions(actions_tensor, is_exit=is_exit) - - # Rewards for (state, next_state) pairs as expected by ModifiedDB. - all_log_rewards = torch.randn(n_transitions, 2, device=device) - - transitions = _DummyTransitions( - states=states, - next_states=next_states, - actions=actions, - all_log_rewards=all_log_rewards, - has_log_probs=False, - log_probs=None, - conditions=None, - ) - - # Precomputed log-probs for pf/pb distributions. - # Keep same length as non-sink count to align with mask slices. - non_sink_count = int((~is_sink_state).sum().item()) - log_pf_action = torch.randn(non_sink_count, device=device) - log_pf_exit = torch.randn(non_sink_count, device=device) - log_pf_exit_next = torch.randn(non_sink_count, device=device) - log_pb_action = torch.randn(non_sink_count, device=device) - - pf_estimator = _DummyEstimator(log_pf_action, log_pf_exit) - pb_estimator = _DummyEstimator(log_pb_action, log_pf_exit_next) - - model = ModifiedDBGFlowNet.__new__(ModifiedDBGFlowNet) - torch.nn.Module.__init__(model) - # Minimal attribute set; bypass __init__ to avoid heavy setup. - model.debug = False - model.constant_pb = False - model.pf = pf_estimator - model.pb = pb_estimator - model.log_reward_clip_min = -float("inf") - - return model, transitions - - -def original_get_scores( - model: ModifiedDBGFlowNet, - transitions: _DummyTransitions, - recalculate_all_logprobs: bool = True, -) -> torch.Tensor: - """Copy of ModifiedDBGFlowNet.get_scores for baseline timing.""" - if model.debug and transitions.is_backward: - raise ValueError("Backward transitions are not supported") - - if len(transitions) == 0: - return torch.tensor(0.0, device=transitions.device) - - mask = ~transitions.next_states.is_sink_state - states = transitions.states[mask] - valid_next_states = transitions.next_states[mask] - actions = transitions.actions[mask] - all_log_rewards = transitions.all_log_rewards[mask] - - if model.debug: - from gfn.gflownet.detailed_balance import check_compatibility - - check_compatibility(states, actions, transitions) # type: ignore[arg-type] - - if transitions.conditions is not None: - with has_conditions_exception_handler("pf", model.pf): # type: ignore[name-defined] - module_output = model.pf(states, transitions.conditions[mask]) - else: - with no_conditions_exception_handler("pf", model.pf): # type: ignore[name-defined] - module_output = model.pf(states) - - if len(states) == 0: - return torch.tensor(0.0, device=transitions.device) - - pf_dist = model.pf.to_probability_distribution(states, module_output) # type: ignore[arg-type] - - if transitions.has_log_probs and not recalculate_all_logprobs: - valid_log_pf_actions = transitions[mask].log_probs - assert valid_log_pf_actions is not None - else: - valid_log_pf_actions = pf_dist.log_prob(actions.tensor) - exit_action_tensor = actions.__class__.exit_action.to( - actions.tensor.device, dtype=actions.tensor.dtype - ).expand_as(actions.tensor) - valid_log_pf_s_exit = pf_dist.log_prob(exit_action_tensor) - - if transitions.conditions is not None: - with has_conditions_exception_handler("pf", model.pf): # type: ignore[name-defined] - module_output = model.pf(valid_next_states, transitions.conditions[mask]) - else: - with no_conditions_exception_handler("pf", model.pf): # type: ignore[name-defined] - module_output = model.pf(valid_next_states) - - valid_log_pf_s_prime_exit = model.pf.to_probability_distribution( - valid_next_states, module_output # type: ignore[arg-type] - ).log_prob(exit_action_tensor[: len(valid_next_states)]) - - non_exit_actions = actions[~actions.is_exit] - - if model.pb is not None: - if transitions.conditions is not None: - with has_conditions_exception_handler("pb", model.pb): # type: ignore[name-defined] - module_output = model.pb(valid_next_states, transitions.conditions[mask]) - else: - with no_conditions_exception_handler("pb", model.pb): # type: ignore[name-defined] - module_output = model.pb(valid_next_states) - - valid_log_pb_actions = model.pb.to_probability_distribution( - valid_next_states, module_output # type: ignore[arg-type] - ).log_prob(non_exit_actions.tensor) - else: - valid_log_pb_actions = torch.zeros_like(valid_log_pf_s_exit) - - preds = all_log_rewards[:, 0] + valid_log_pf_actions + valid_log_pf_s_prime_exit - targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit - - scores = preds - targets - if model.debug and torch.any(torch.isinf(scores)): - raise ValueError("scores contains inf") - - return scores - - -def run_once( - mode: str, - n_transitions: int, - use_compile: bool = False, - device: str | torch.device = "cpu", -) -> float: - """Return median time (seconds) for the chosen mode.""" - model, transitions = build_model_and_data( - n_transitions=n_transitions, - device=device, - ) - - bench: Callable[[], Any] - compiled_get_scores: Callable | None = None - - if mode == "original": - - def bench_original(): - return original_get_scores(model, transitions, recalculate_all_logprobs=True) - - bench = bench_original - elif mode == "current": - if use_compile: - compiled_get_scores = torch.compile( - model.get_scores, - fullgraph=False, - dynamic=False, - mode="reduce-overhead", - ) - - def bench_current(): - fn = ( - compiled_get_scores - if compiled_get_scores is not None - else model.get_scores - ) - return fn(transitions) # type: ignore[arg-type] - - bench = bench_current - else: - raise ValueError(mode) - - t = benchmark.Timer( - stmt="bench()", - globals={"bench": bench}, - setup="", - num_threads=torch.get_num_threads(), - ).blocked_autorange(min_run_time=0.5) - return t.median - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--sizes", - nargs="+", - default=["65536", "131072", "262144"], - help="Number of transitions per batch to benchmark (larger to surface runtime differences).", - ) - parser.add_argument( - "--compile", - action="store_true", - help="Use torch.compile for the current/optimized get_scores.", - ) - parser.add_argument( - "--device", - default="cpu", - help="Device to run on (e.g., cpu, mps, cuda).", - ) - args = parser.parse_args() - - print("Benchmarking ModifiedDBGFlowNet.get_scores (Modified DB)") - print(f"torch version: {torch.__version__}") - print(f"num threads: {torch.get_num_threads()}") - print() - print(f"{'n_trans':>10} {'orig (ms)':>12} {'curr (ms)':>12} {'speedup':>8}") - - for size in args.sizes: - n_transitions = int(size) - t_orig = ( - run_once( - "original", - n_transitions, - device=args.device, - ) - * 1e3 - ) - t_curr = ( - run_once( - "current", - n_transitions, - use_compile=args.compile, - device=args.device, - ) - * 1e3 - ) - speedup = t_orig / t_curr if t_curr > 0 else float("inf") - print(f"{n_transitions:10d} {t_orig:12.3f} {t_curr:12.3f} {speedup:8.2f}x") - - -if __name__ == "__main__": - main() diff --git a/tutorials/misc/bench_subtb_get_scores.py b/tutorials/misc/bench_subtb_get_scores.py deleted file mode 100644 index 2a0370e7..00000000 --- a/tutorials/misc/bench_subtb_get_scores.py +++ /dev/null @@ -1,248 +0,0 @@ -""" -Micro-benchmark for SubTBGFlowNet.get_scores vectorized vs original loop. - -This isolates get_scores by monkeypatching dependencies (calculate_preds/targets, -get_pfs_and_pbs, masks) to synthetic tensors so we can time the core logic. -Run on CPU; adjust sizes below to probe different max_len / batch regimes. -""" - -from __future__ import annotations - -import argparse -from types import MethodType -from typing import Any, Callable, Tuple - -import torch -from torch.utils import benchmark - -from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet - - -class _DummyTrajectories: - """Minimal trajectories carrier for benchmarking get_scores.""" - - def __init__(self, terminating_idx: torch.Tensor, max_length: int): - self.terminating_idx = terminating_idx - self.max_length = max_length - self.n_trajectories = terminating_idx.shape[0] - - def __len__(self) -> int: - return self.n_trajectories - - -def build_model_and_data( - max_len: int, n_traj: int, seed: int = 0, device: str | torch.device | None = None -) -> Tuple[SubTBGFlowNet, _DummyTrajectories, list[torch.Tensor], list[torch.Tensor]]: - torch.manual_seed(seed) - device = torch.device(device) if device is not None else torch.device("cpu") - terminating_idx = torch.randint(1, max_len + 1, (n_traj,), device=device) - # In the real pipeline, trajectories carry log_rewards computed from the env. - # The vectorized get_scores now asserts on its presence, so seed a dummy tensor here. - log_rewards = torch.randn(n_traj, device=device) - log_pf_trajectories = torch.randn(max_len, n_traj, device=device) - log_pb_trajectories = torch.randn(max_len, n_traj, device=device) - log_state_flows = torch.randn(max_len, n_traj, device=device) - sink_states_mask = torch.zeros(max_len, n_traj, dtype=torch.bool, device=device) - is_terminal_mask = torch.zeros(max_len, n_traj, dtype=torch.bool, device=device) - - preds_list = [ - torch.randn(max_len + 1 - i, n_traj, device=device) - for i in range(1, max_len + 1) - ] - targets_list = [ - torch.randn(max_len + 1 - i, n_traj, device=device) - for i in range(1, max_len + 1) - ] - - trajectories = _DummyTrajectories( - terminating_idx=terminating_idx, max_length=max_len - ) - - # Build a SubTBGFlowNet instance without running heavy __init__. - model = SubTBGFlowNet.__new__(SubTBGFlowNet) - torch.nn.Module.__init__(model) - model.debug = False - model.log_reward_clip_min = float("-inf") - - # Monkeypatch the dependencies used inside get_scores to deterministic tensors. - model.get_pfs_and_pbs = MethodType( - lambda self, traj, recalculate_all_logprobs=True: ( - log_pf_trajectories, - log_pb_trajectories, - ), - model, - ) - model.calculate_log_state_flows = MethodType( - lambda self, _env, _traj, _log_pf: log_state_flows, model - ) - model.calculate_masks = MethodType( - lambda self, _log_state_flows, _traj: (sink_states_mask, is_terminal_mask), - model, - ) - # Attach log_rewards to the dummy trajectories to mirror real trajectories objects. - trajectories.log_rewards = log_rewards - model.calculate_preds = MethodType( - lambda self, _log_pf_cum, _log_state_flows, i: preds_list[i - 1], model - ) - model.calculate_targets = MethodType( - lambda self, _traj, _preds, _log_pb_cum, _log_state_flows, _term_mask, _sink_mask, i: targets_list[ - i - 1 - ], - model, - ) - - return model, trajectories, preds_list, targets_list - - -def original_get_scores( - model: SubTBGFlowNet, env, trajectories -) -> Tuple[list[torch.Tensor], list[torch.Tensor]]: - """Reference implementation (pre-vectorized) for comparison.""" - log_pf_trajectories_, log_pb_trajectories_ = model.get_pfs_and_pbs( - trajectories, recalculate_all_logprobs=True - ) - - log_pf_trajectories_cum = model.cumulative_logprobs( - trajectories, log_pf_trajectories_ - ) - log_pb_trajectories_cum = model.cumulative_logprobs( - trajectories, log_pb_trajectories_ - ) - - log_state_flows_ = model.calculate_log_state_flows( - env, trajectories, log_pf_trajectories_ - ) - sink_states_mask_, is_terminal_mask_ = model.calculate_masks( - log_state_flows_, trajectories - ) - - flattening_masks_orig = [] - scores_orig = [] - for i in range(1, 1 + trajectories.max_length): - preds = model.calculate_preds(log_pf_trajectories_cum, log_state_flows_, i) - targets = model.calculate_targets( - trajectories, - preds, - log_pb_trajectories_cum, - log_state_flows_, - is_terminal_mask_, - sink_states_mask_, - i, - ) - - flattening_mask = trajectories.terminating_idx.lt( - torch.arange( - i, - trajectories.max_length + 1, - device=trajectories.terminating_idx.device, - ).unsqueeze(-1) - ) - - flat_preds = preds[~flattening_mask] - if model.debug and torch.any(torch.isnan(flat_preds)): - raise ValueError("NaN in preds") - - flat_targets = targets[~flattening_mask] - if model.debug and torch.any(torch.isnan(flat_targets)): - raise ValueError("NaN in targets") - - flattening_masks_orig.append(flattening_mask) - scores_orig.append(preds - targets) - - return scores_orig, flattening_masks_orig - - -def run_once( - mode: str, - max_len: int, - n_traj: int, - use_compile: bool = False, - device: str | torch.device = "cpu", -) -> float: - """Return median time (seconds) for the chosen mode. Optionally uses torch.compile and device selection.""" - model, trajectories, _, _ = build_model_and_data(max_len, n_traj, device=device) - env_obj: Any = object() - bench: Callable[[], Any] - compiled_get_scores: Callable | None = None - - if mode == "original": - - def bench_original(): - return original_get_scores(model, env_obj, trajectories) # type: ignore[arg-type] - - bench = bench_original - elif mode == "vectorized": - if use_compile: - # Compile only after monkeypatching, so we capture the correct bound method. - compiled_get_scores = torch.compile( - model.get_scores, fullgraph=False, dynamic=False, mode="reduce-overhead" - ) - - def bench_vectorized(): - fn = ( - compiled_get_scores - if compiled_get_scores is not None - else model.get_scores - ) - return fn(env_obj, trajectories) # type: ignore[arg-type] - - bench = bench_vectorized - else: - raise ValueError(mode) - - t = benchmark.Timer( - stmt="bench()", - globals={"bench": bench}, - setup="", - num_threads=torch.get_num_threads(), - ).blocked_autorange(min_run_time=0.5) - return t.median - - -def main(): - parser = argparse.ArgumentParser() - # Defaults scaled ~10x to stress larger workloads; override with --sizes if needed. - parser.add_argument( - "--sizes", - nargs="+", - default=["80x640", "160x1280", "320x2560"], - ) - parser.add_argument( - "--compile", - action="store_true", - help="Use torch.compile on the vectorized get_scores.", - ) - parser.add_argument( - "--device", - default="cpu", - help="Device to run on (e.g., cpu, mps, cuda).", - ) - args = parser.parse_args() - - print("Benchmarking SubTBGFlowNet.get_scores (CPU)") - print(f"torch version: {torch.__version__}") - print(f"num threads: {torch.get_num_threads()}") - print() - print(f"{'size':>10} {'orig (ms)':>12} {'vec (ms)':>12} {'speedup':>8}") - - for size in args.sizes: - max_len_s, n_traj_s = size.lower().split("x") - max_len = int(max_len_s) - n_traj = int(n_traj_s) - t_orig = run_once("original", max_len, n_traj, device=args.device) * 1e3 - t_vec = ( - run_once( - "vectorized", - max_len, - n_traj, - use_compile=args.compile, - device=args.device, - ) - * 1e3 - ) - speedup = t_orig / t_vec if t_vec > 0 else float("inf") - print(f"{size:>10} {t_orig:12.3f} {t_vec:12.3f} {speedup:8.2f}x") - - -if __name__ == "__main__": - main() From e0af06f5878bea0051b6d6ad4d3f6530d19c8543 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 13 Dec 2025 12:35:13 -0500 Subject: [PATCH 16/16] update --- testing/test_gflownet.py | 169 +++++++++++++++++++++++++++++++++++- testing/test_gflownets.py | 175 -------------------------------------- 2 files changed, 168 insertions(+), 176 deletions(-) delete mode 100644 testing/test_gflownets.py diff --git a/testing/test_gflownet.py b/testing/test_gflownet.py index 48cbd2e2..135c446e 100644 --- a/testing/test_gflownet.py +++ b/testing/test_gflownet.py @@ -3,12 +3,14 @@ from gfn.containers import StatesContainer, Trajectories from gfn.containers.base import Container -from gfn.estimators import DiscretePolicyEstimator +from gfn.estimators import DiscretePolicyEstimator, ScalarEstimator from gfn.gflownet import FMGFlowNet, TBGFlowNet from gfn.gflownet.base import loss_reduce +from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet from gfn.gym import Box, HyperGrid from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP from gfn.preprocessors import KHotPreprocessor +from gfn.samplers import Sampler from gfn.states import DiscreteStates from gfn.utils.handlers import ( has_conditions_exception_handler, @@ -191,3 +193,168 @@ def flow_matching_loss_original( ) torch.testing.assert_close(loss_vectorized, loss_original) + + +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_subtb_get_scores_vectorized_matches_original(seed: int): + torch.manual_seed(seed) + n_traj = 4 + + # Deterministic HyperGrid env and frozen estimators so real methods can run. + env = HyperGrid(ndim=2, height=3, device="cpu", debug=False) + preproc = KHotPreprocessor(height=env.height, ndim=env.ndim) + + # Tiny MLPs with random weights (frozen for determinism). + module_pf = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions) + module_pb = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions - 1) + module_logF = MLP(input_dim=preproc.output_dim, output_dim=1) + for mod in (module_pf, module_pb, module_logF): + for p in mod.parameters(): + p.requires_grad_(False) + + pf = DiscretePolicyEstimator( + module=module_pf, + n_actions=env.n_actions, + preprocessor=preproc, + is_backward=False, + ) + pb = DiscretePolicyEstimator( + module=module_pb, n_actions=env.n_actions, preprocessor=preproc, is_backward=True + ) + logF = ScalarEstimator(module=module_logF, preprocessor=preproc) + + # Initialize model via __init__ to set up real methods. + model = SubTBGFlowNet( + pf=pf, pb=pb, logF=logF, weighting="geometric_within", lamda=0.9 + ) + model.debug = False + model.log_reward_clip_min = float("-inf") + model.eval() + pf.eval() + pb.eval() + logF.eval() + + # Sample a deterministic batch of trajectories with frozen estimators. + sampler = Sampler(estimator=pf) + trajectories: Trajectories = sampler.sample_trajectories( + env, + n=n_traj, + epsilon=0.0, + save_logprobs=True, + save_estimator_outputs=False, + ) + max_len = trajectories.max_length # noqa: F841 used implicitly by shapes + + def original_get_scores(self, env, trajectories, recalculate_all_logprobs=True): + log_pf_trajectories_, log_pb_trajectories_ = self.get_pfs_and_pbs( + trajectories, recalculate_all_logprobs=recalculate_all_logprobs + ) + + log_pf_trajectories_cum = self.cumulative_logprobs( + trajectories, log_pf_trajectories_ + ) + log_pb_trajectories_cum = self.cumulative_logprobs( + trajectories, log_pb_trajectories_ + ) + + log_state_flows_ = self.calculate_log_state_flows( + env, trajectories, log_pf_trajectories_ + ) + sink_states_mask_, is_terminal_mask_ = self.calculate_masks( + log_state_flows_, trajectories + ) + + flattening_masks_orig = [] + scores_orig = [] + for i in range(1, 1 + trajectories.max_length): + preds = self.calculate_preds(log_pf_trajectories_cum, log_state_flows_, i) + targets = self.calculate_targets( + trajectories, + preds, + log_pb_trajectories_cum, + log_state_flows_, + is_terminal_mask_, + sink_states_mask_, + i, + ) + + flattening_mask = trajectories.terminating_idx.lt( + torch.arange( + i, + trajectories.max_length + 1, + device=trajectories.terminating_idx.device, + ).unsqueeze(-1) + ) + + flat_preds = preds[~flattening_mask] + if self.debug and torch.any(torch.isnan(flat_preds)): + raise ValueError("NaN in preds") + + flat_targets = targets[~flattening_mask] + if self.debug and torch.any(torch.isnan(flat_targets)): + raise ValueError("NaN in targets") + + flattening_masks_orig.append(flattening_mask) + scores_orig.append(preds - targets) + + return scores_orig, flattening_masks_orig + + def normalize_scores_masks( + scores, masks, trajectories: Trajectories + ) -> tuple[torch.Tensor, torch.Tensor]: + """Convert list outputs to padded tensors; pass tensors through unchanged.""" + if isinstance(scores, torch.Tensor): + assert isinstance(masks, torch.Tensor) + return scores, masks + + assert isinstance(scores, (list, tuple)) + assert isinstance(masks, (list, tuple)) + + max_len = trajectories.max_length + n_traj = ( + trajectories.n_trajectories + if hasattr(trajectories, "n_trajectories") + else len(trajectories) + ) + device = trajectories.terminating_idx.device + dtype = scores[0].dtype + + scores_padded = torch.zeros( + (max_len, max_len, n_traj), dtype=dtype, device=device + ) + masks_padded = torch.ones( + (max_len, max_len, n_traj), dtype=torch.bool, device=device + ) + + for i, (s, m) in enumerate(zip(scores, masks), start=1): + seq_len = s.shape[0] + scores_padded[i - 1, :seq_len] = s + masks_padded[i - 1, :seq_len] = m + + return scores_padded, masks_padded + + # Recompute logprobs to ensure PF/PB are evaluated for both paths. + orig_scores_list, orig_masks_list = original_get_scores( + model, env, trajectories, recalculate_all_logprobs=True + ) + vec_scores, vec_masks = model.get_scores( + env, trajectories, recalculate_all_logprobs=True + ) # type: ignore + + vec_scores_t, vec_masks_t = normalize_scores_masks( + vec_scores, vec_masks, trajectories + ) + orig_scores_t, orig_masks_t = normalize_scores_masks( + orig_scores_list, orig_masks_list, trajectories + ) + + valid_mask = ~orig_masks_t + if not torch.allclose( + vec_scores_t[valid_mask], orig_scores_t[valid_mask], equal_nan=True + ): + max_diff = (vec_scores_t[valid_mask] - orig_scores_t[valid_mask]).abs().max() + raise AssertionError( + f"Score mismatch on valid positions; max_abs_diff={max_diff.item()}" + ) + + torch.testing.assert_close(vec_masks_t, orig_masks_t, equal_nan=True) diff --git a/testing/test_gflownets.py b/testing/test_gflownets.py deleted file mode 100644 index e27a7c74..00000000 --- a/testing/test_gflownets.py +++ /dev/null @@ -1,175 +0,0 @@ -import pytest -import torch - -from gfn.containers.trajectories import Trajectories -from gfn.estimators import DiscretePolicyEstimator, ScalarEstimator -from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet -from gfn.gym.hypergrid import HyperGrid -from gfn.preprocessors import KHotPreprocessor -from gfn.samplers import Sampler -from gfn.utils.modules import MLP - - -@pytest.mark.parametrize("seed", [0, 1, 2]) -def test_subtb_get_scores_vectorized_matches_original(seed: int): - torch.manual_seed(seed) - n_traj = 4 - - # Deterministic HyperGrid env and frozen estimators so real methods can run. - env = HyperGrid(ndim=2, height=3, device="cpu", debug=False) - preproc = KHotPreprocessor(height=env.height, ndim=env.ndim) - - # Tiny MLPs with random weights (frozen for determinism). - module_pf = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions) - module_pb = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions - 1) - module_logF = MLP(input_dim=preproc.output_dim, output_dim=1) - for mod in (module_pf, module_pb, module_logF): - for p in mod.parameters(): - p.requires_grad_(False) - - pf = DiscretePolicyEstimator( - module=module_pf, - n_actions=env.n_actions, - preprocessor=preproc, - is_backward=False, - ) - pb = DiscretePolicyEstimator( - module=module_pb, n_actions=env.n_actions, preprocessor=preproc, is_backward=True - ) - logF = ScalarEstimator(module=module_logF, preprocessor=preproc) - - # Initialize model via __init__ to set up real methods. - model = SubTBGFlowNet( - pf=pf, pb=pb, logF=logF, weighting="geometric_within", lamda=0.9 - ) - model.debug = False - model.log_reward_clip_min = float("-inf") - model.eval() - pf.eval() - pb.eval() - logF.eval() - - # Sample a deterministic batch of trajectories with frozen estimators. - sampler = Sampler(estimator=pf) - trajectories: Trajectories = sampler.sample_trajectories( - env, - n=n_traj, - epsilon=0.0, - save_logprobs=True, - save_estimator_outputs=False, - ) - max_len = trajectories.max_length # noqa: F841 used implicitly by shapes - - def original_get_scores(self, env, trajectories, recalculate_all_logprobs=True): - log_pf_trajectories_, log_pb_trajectories_ = self.get_pfs_and_pbs( - trajectories, recalculate_all_logprobs=recalculate_all_logprobs - ) - - log_pf_trajectories_cum = self.cumulative_logprobs( - trajectories, log_pf_trajectories_ - ) - log_pb_trajectories_cum = self.cumulative_logprobs( - trajectories, log_pb_trajectories_ - ) - - log_state_flows_ = self.calculate_log_state_flows( - env, trajectories, log_pf_trajectories_ - ) - sink_states_mask_, is_terminal_mask_ = self.calculate_masks( - log_state_flows_, trajectories - ) - - flattening_masks_orig = [] - scores_orig = [] - for i in range(1, 1 + trajectories.max_length): - preds = self.calculate_preds(log_pf_trajectories_cum, log_state_flows_, i) - targets = self.calculate_targets( - trajectories, - preds, - log_pb_trajectories_cum, - log_state_flows_, - is_terminal_mask_, - sink_states_mask_, - i, - ) - - flattening_mask = trajectories.terminating_idx.lt( - torch.arange( - i, - trajectories.max_length + 1, - device=trajectories.terminating_idx.device, - ).unsqueeze(-1) - ) - - flat_preds = preds[~flattening_mask] - if self.debug and torch.any(torch.isnan(flat_preds)): - raise ValueError("NaN in preds") - - flat_targets = targets[~flattening_mask] - if self.debug and torch.any(torch.isnan(flat_targets)): - raise ValueError("NaN in targets") - - flattening_masks_orig.append(flattening_mask) - scores_orig.append(preds - targets) - - return scores_orig, flattening_masks_orig - - def normalize_scores_masks( - scores, masks, trajectories: Trajectories - ) -> tuple[torch.Tensor, torch.Tensor]: - """Convert list outputs to padded tensors; pass tensors through unchanged.""" - if isinstance(scores, torch.Tensor): - assert isinstance(masks, torch.Tensor) - return scores, masks - - assert isinstance(scores, (list, tuple)) - assert isinstance(masks, (list, tuple)) - - max_len = trajectories.max_length - n_traj = ( - trajectories.n_trajectories - if hasattr(trajectories, "n_trajectories") - else len(trajectories) - ) - device = trajectories.terminating_idx.device - dtype = scores[0].dtype - - scores_padded = torch.zeros( - (max_len, max_len, n_traj), dtype=dtype, device=device - ) - masks_padded = torch.ones( - (max_len, max_len, n_traj), dtype=torch.bool, device=device - ) - - for i, (s, m) in enumerate(zip(scores, masks), start=1): - seq_len = s.shape[0] - scores_padded[i - 1, :seq_len] = s - masks_padded[i - 1, :seq_len] = m - - return scores_padded, masks_padded - - # Recompute logprobs to ensure PF/PB are evaluated for both paths. - orig_scores_list, orig_masks_list = original_get_scores( - model, env, trajectories, recalculate_all_logprobs=True - ) - vec_scores, vec_masks = model.get_scores( - env, trajectories, recalculate_all_logprobs=True - ) # type: ignore - - vec_scores_t, vec_masks_t = normalize_scores_masks( - vec_scores, vec_masks, trajectories - ) - orig_scores_t, orig_masks_t = normalize_scores_masks( - orig_scores_list, orig_masks_list, trajectories - ) - - valid_mask = ~orig_masks_t - if not torch.allclose( - vec_scores_t[valid_mask], orig_scores_t[valid_mask], equal_nan=True - ): - max_diff = (vec_scores_t[valid_mask] - orig_scores_t[valid_mask]).abs().max() - raise AssertionError( - f"Score mismatch on valid positions; max_abs_diff={max_diff.item()}" - ) - - torch.testing.assert_close(vec_masks_t, orig_masks_t, equal_nan=True)