diff --git a/docs/source/guides/losses.md b/docs/source/guides/losses.md index 45c79f15..b6535b18 100644 --- a/docs/source/guides/losses.md +++ b/docs/source/guides/losses.md @@ -5,7 +5,7 @@ GFlowNets can be trained with different losses, each of which requires a differe Currently, the implemented losses are: - Flow Matching: This is the original loss, and is mostly made available for completeness. It is slow to compute this loss, and also hard to optimize, so is generally not recommended for any significantly hard to learn problem. -- Detailed Balance (and it's modified variant). +- Detailed Balance (and its modified variant). In forward-looking mode, rewards must be defined on edges; the current implementation treats the edge reward as the difference between the successor and current state rewards, so only enable this when that matches your environment. - Trajectory Balance -- Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined [here](https://www.semanticscholar.org/reader/f2c32fe3f7f3e2e9d36d833e32ec55fc93f900f5). Other strategies exist and are implemented [here](https://github.com/gfnorg/torchgfn/tree/master/src/gfn/losses/sub_trajectory_balance.py). +- Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined [here](https://www.semanticscholar.org/reader/f2c32fe3f7f3e2e9d36d833e32ec55fc93f900f5). Other strategies exist and are implemented [here](https://github.com/gfnorg/torchgfn/tree/master/src/gfn/losses/sub_trajectory_balance.py). When using geometric-based weighting, the 'mean' reduction is not supported; requests for a mean reduction are coerced to a sum (a warning is emitted when debug is enabled). - Log Partition Variance loss. Introduced [here](https://arxiv.org/abs/2302.05446) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 5542019b..9a463436 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -1,4 +1,3 @@ -import math import warnings from abc import ABC, abstractmethod from typing import Any, Generic, Tuple, TypeVar @@ -48,6 +47,16 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): log_reward_clip_min = float("-inf") # Default off. + def __init__(self, debug: bool = False) -> None: + """Initialize shared GFlowNet state. + + Args: + debug: If True, keep runtime safety checks and warnings active. Set False + in compiled hot paths to avoid graph breaks; use True in tests/debugging. + """ + super().__init__() + self.debug = debug + @abstractmethod def sample_trajectories( self, @@ -148,6 +157,8 @@ def loss_from_trajectories( def assert_finite_gradients(self): """Asserts that the gradients are finite.""" + if not self.debug: + return for p in self.parameters(): if p.grad is not None: if not torch.isfinite(p.grad).all(): @@ -155,6 +166,8 @@ def assert_finite_gradients(self): def assert_finite_parameters(self): """Asserts that the parameters are finite.""" + if not self.debug: + return for p in self.parameters(): if not torch.isfinite(p).all(): raise RuntimeError("GFlowNet has non-finite parameters") @@ -177,6 +190,7 @@ def __init__( pb: Estimator | None, constant_pb: bool = False, log_reward_clip_min: float = float("-inf"), + debug: bool = False, ) -> None: """Initializes a PFBasedGFlowNet instance. @@ -189,9 +203,10 @@ def __init__( explicitly by user to ensure that pb is an Estimator except under this special case. log_reward_clip_min: If finite, clips log rewards to this value. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ - super().__init__() + super().__init__(debug=debug) # Technical note: pb may be constant for a variety of edge cases, for example, # if all terminal states can be reached with exactly the same number of # trajectories, and we assume a uniform backward policy, then we can omit the pb @@ -365,23 +380,34 @@ def get_scores( ) assert log_pf_trajectories is not None - total_log_pf_trajectories = log_pf_trajectories.sum(dim=0) - total_log_pb_trajectories = log_pb_trajectories.sum(dim=0) + total_log_pf_trajectories = log_pf_trajectories.sum(dim=0) # [N] + total_log_pb_trajectories = log_pb_trajectories.sum(dim=0) # [N] log_rewards = trajectories.log_rewards assert log_rewards is not None - - if math.isfinite(self.log_reward_clip_min): + # Fast path: skip clamp when log_reward_clip_min is -inf to avoid extra work. + # TODO: Do we need log reward clamping at all? + if self.log_reward_clip_min != float("-inf"): log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) - if torch.any(torch.isinf(total_log_pf_trajectories)): - raise ValueError("Infinite pf logprobs found") - if torch.any(torch.isinf(total_log_pb_trajectories)): - raise ValueError("Infinite pb logprobs found") - - assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,) - assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,) - return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards + # Keep runtime safety checks under `debug` to avoid graph breaks in torch.compile. + if self.debug: + if torch.any(torch.isinf(total_log_pf_trajectories)): + raise ValueError("Infinite pf logprobs found") + if torch.any(torch.isinf(total_log_pb_trajectories)): + raise ValueError("Infinite pb logprobs found") + assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,) + assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,) + + # Fused (pf - pb) then subtract rewards; keep it branch-free/out-of-place + # to stay friendly to torch.compile graphs. + scores = torch.sub( + total_log_pf_trajectories, total_log_pb_trajectories, alpha=1.0 + ) + # Subtract rewards in a separate op to avoid in-place mutations (graph-stable) + # while still keeping only one extra temporary. + scores = scores - log_rewards + return scores def to_training_samples(self, trajectories: Trajectories) -> Trajectories: """Returns the input trajectories as training samples. diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index feff5462..801cfa7a 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -1,4 +1,3 @@ -import math from typing import Tuple import torch @@ -59,7 +58,10 @@ class DBGFlowNet(PFBasedGFlowNet[Transitions]): pb: The backward policy estimator. logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states. - forward_looking: Whether to use the forward-looking GFN loss. + forward_looking: Whether to use the forward-looking GFN loss. When True, + rewards must be defined over edges; this implementation treats the edge + reward as the difference between the successor and current state rewards, + so only valid if the environment follows that assumption. constant_pb: Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. log_reward_clip_min: If finite, clips log rewards to this value. @@ -73,6 +75,7 @@ def __init__( forward_looking: bool = False, constant_pb: bool = False, log_reward_clip_min: float = -float("inf"), + debug: bool = False, ) -> None: """Initializes a DBGFlowNet instance. @@ -82,16 +85,24 @@ def __init__( pb is therefore always 1. logF: A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states. - forward_looking: Whether to use the forward-looking GFN loss. + forward_looking: Whether to use the forward-looking GFN loss. When True, + rewards should be defined over edges; this implementation treats the + edge reward as the difference between the successor and current state + rewards, so only valid if the environment follows that assumption. constant_pb: Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. log_reward_clip_min: If finite, clips log rewards to this value. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ super().__init__( - pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min + pf, + pb, + constant_pb=constant_pb, + log_reward_clip_min=log_reward_clip_min, + debug=debug, ) # Disallow recurrent PF for transition-based DB @@ -158,7 +169,10 @@ def get_pfs_and_pbs( ) def get_scores( - self, env: Env, transitions: Transitions, recalculate_all_logprobs: bool = True + self, + env: Env, + transitions: Transitions, + recalculate_all_logprobs: bool = True, ) -> torch.Tensor: r"""Calculates the scores for a batch of transitions. @@ -174,88 +188,103 @@ def get_scores( A tensor of shape (n_transitions,) representing the scores for each transition. """ - if transitions.is_backward: + # Guard bad inputs under debug to avoid graph breaks in torch.compile. + if self.debug and transitions.is_backward: raise ValueError("Backward transitions are not supported") states = transitions.states actions = transitions.actions + conditions = ( + transitions.conditions + ) # reuse locally to avoid repeated attribute lookups + next_states = ( + transitions.next_states + ) # reuse to avoid repeated attribute lookups if len(states) == 0: return torch.tensor(0.0, device=transitions.device) - check_compatibility(states, actions, transitions) - assert ( - not transitions.states.is_sink_state.any() - ), "Transition from sink state is not allowed. This is a bug." + if self.debug: + check_compatibility(states, actions, transitions) + assert ( + not transitions.states.is_sink_state.any() + ), "Transition from sink state is not allowed. This is a bug." - ### Compute log_pf and log_pb + # Compute log_pf and log_pb log_pf, log_pb = self.get_pfs_and_pbs(transitions, recalculate_all_logprobs) - ### Compute log_F_s + # Compute log_F_s # LogF is potentially a conditional computation. - if transitions.conditions is not None: + if conditions is not None: with has_conditions_exception_handler("logF", self.logF): - log_F_s = self.logF(states, transitions.conditions).squeeze(-1) + log_F_s = self.logF(states, conditions).squeeze(-1) else: with no_conditions_exception_handler("logF", self.logF): log_F_s = self.logF(states).squeeze(-1) - ### Compute log_F_s_next - log_F_s_next = torch.zeros_like(log_F_s) + # Compute log_F_s_next + # Preallocate once and fill; write terminating rewards first to avoid an extra zero fill. + log_F_s_next = torch.empty_like(log_F_s) is_terminating = transitions.is_terminating is_intermediate = ~is_terminating - # Assign log_F_s_next for intermediate next states - interm_next_states = transitions.next_states[is_intermediate] - # log_F is potentially a conditional computation. - if transitions.conditions is not None: - with has_conditions_exception_handler("logF", self.logF): - log_F_s_next[is_intermediate] = self.logF( - interm_next_states, - transitions.conditions[is_intermediate], - ).squeeze(-1) - else: - with no_conditions_exception_handler("logF", self.logF): - log_F_s_next[is_intermediate] = self.logF(interm_next_states).squeeze(-1) + # Assign log_F_s_next for terminating transitions directly from clamped rewards. + log_rewards = transitions.log_rewards + assert log_rewards is not None + log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) + log_F_s_next[is_terminating] = log_rewards[is_terminating] + + # Assign log_F_s_next for intermediate next states (skip work if none). + if torch.any(is_intermediate): + interm_idx = is_intermediate.nonzero(as_tuple=True)[0] + interm_next_states = next_states[interm_idx] + # log_F is potentially a conditional computation. + if conditions is not None: + with has_conditions_exception_handler("logF", self.logF): + log_F_s_next[interm_idx] = self.logF( + interm_next_states, + conditions[interm_idx], + ).squeeze(-1) + else: + with no_conditions_exception_handler("logF", self.logF): + log_F_s_next[interm_idx] = self.logF(interm_next_states).squeeze(-1) # Apply forward-looking if applicable if self.forward_looking: - import warnings - - warnings.warn( - "Rewards should be defined over edges in forward-looking settings. " - "The current implementation is a special case of this, where the edge " - "reward is defined as the difference between the reward of two states " - "that the edge connects. If your environment is not the case, " - "forward-looking may be inappropriate." - ) + # Keep explanatory warning only in debug to avoid compile-time graph breaks. + if self.debug: + import warnings + + warnings.warn( + "Rewards should be defined over edges in forward-looking settings. " + "The current implementation is a special case of this, where the edge " + "reward is defined as the difference between the reward of two states " + "that the edge connects. If your environment is not the case, " + "forward-looking may be inappropriate." + ) # Reward calculation can also be conditional. - if transitions.conditions is not None: - log_rewards_state = env.log_reward(states, transitions.conditions) # type: ignore - log_rewards_next = env.log_reward( - interm_next_states, transitions.conditions[is_intermediate] # type: ignore - ) + if conditions is not None: + log_rewards_state = env.log_reward(states, conditions) # type: ignore else: log_rewards_state = env.log_reward(states) - log_rewards_next = env.log_reward(interm_next_states) - if math.isfinite(self.log_reward_clip_min): - log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min) - log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min) + log_rewards_state = log_rewards_state.clamp_min(self.log_reward_clip_min) log_F_s = log_F_s + log_rewards_state - log_F_s_next[is_intermediate] = ( - log_F_s_next[is_intermediate] + log_rewards_next - ) - # Assign log_F_s_next for terminating transitions as log_rewards - log_rewards = transitions.log_rewards - assert log_rewards is not None - if math.isfinite(self.log_reward_clip_min): - log_rewards = log_rewards.clamp_min(self.log_reward_clip_min) - log_F_s_next[is_terminating] = log_rewards[is_terminating] + if torch.any(is_intermediate): + if conditions is not None: + log_rewards_next = env.log_reward( + next_states[interm_idx], + conditions[interm_idx], # type: ignore + ) + else: + log_rewards_next = env.log_reward(next_states[interm_idx]) - ### Compute scores + log_rewards_next = log_rewards_next.clamp_min(self.log_reward_clip_min) + log_F_s_next[interm_idx] = log_F_s_next[interm_idx] + log_rewards_next + + # Compute scores preds = log_pf + log_F_s targets = log_pb + log_F_s_next scores = preds - targets @@ -279,17 +308,23 @@ def loss( transitions: The Transitions object to compute the loss with. recalculate_all_logprobs: Whether to re-evaluate all logprobs. reduction: The reduction method to use ('mean', 'sum', or 'none'). + Run with self.debug=False for improved performance. Returns: The computed detailed balance loss as a tensor. The shape depends on the reduction method. """ - warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs) - scores = self.get_scores(env, transitions, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs) + scores = self.get_scores( + env, + transitions, + recalculate_all_logprobs=recalculate_all_logprobs, + ) scores = scores**2 loss = loss_reduce(scores, reduction) - if torch.isnan(loss).any(): + if self.debug and torch.isnan(loss).any(): raise ValueError("loss is nan") return loss @@ -327,6 +362,7 @@ def __init__( pf: Estimator, pb: Estimator | None, constant_pb: bool = False, + debug: bool = False, ) -> None: """Initializes a ModifiedDBGFlowNet instance. @@ -334,12 +370,15 @@ def __init__( pf: Forward policy estimator. pb: Backward policy estimator or None. constant_pb: See base class. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__(pf, pb, constant_pb=constant_pb, debug=debug) def get_scores( - self, transitions: Transitions, recalculate_all_logprobs: bool = True + self, + transitions: Transitions, + recalculate_all_logprobs: bool = True, ) -> torch.Tensor: """Calculates DAG-GFN-style modified detailed balance scores. @@ -360,31 +399,41 @@ def get_scores( Returns: A tensor of shape (n_transitions,) containing the scores for each transition. """ - if transitions.is_backward: + if self.debug and transitions.is_backward: raise ValueError("Backward transitions are not supported") if len(transitions) == 0: return torch.tensor(0.0, device=transitions.device) + conditions = ( + transitions.conditions + ) # reuse locally to avoid repeated attribute lookups mask = ~transitions.next_states.is_sink_state states = transitions.states[mask] valid_next_states = transitions.next_states[mask] actions = transitions.actions[mask] all_log_rewards = transitions.all_log_rewards[mask] - check_compatibility(states, actions, transitions) + # If no non-sink next states, bail out early to avoid needless estimator calls. + if len(states) == 0: + return torch.tensor(0.0, device=transitions.device) - if transitions.conditions is not None: + if self.debug: + check_compatibility(states, actions, transitions) + + # Single pf forward for current states; reuse the resulting distribution for both + # taken-action log-probs and exit-action log-probs to avoid extra forwards. + if conditions is not None: with has_conditions_exception_handler("pf", self.pf): - module_output = self.pf(states, transitions.conditions[mask]) + pf_outputs = self.pf(states, conditions[mask]) else: with no_conditions_exception_handler("pf", self.pf): - module_output = self.pf(states) + pf_outputs = self.pf(states) if len(states) == 0: return torch.tensor(0.0, device=transitions.device) - pf_dist = self.pf.to_probability_distribution(states, module_output) + pf_dist = self.pf.to_probability_distribution(states, pf_outputs) if transitions.has_log_probs and not recalculate_all_logprobs: valid_log_pf_actions = transitions[mask].log_probs @@ -392,33 +441,36 @@ def get_scores( else: # Evaluate the log PF of the actions sampled off policy. valid_log_pf_actions = pf_dist.log_prob(actions.tensor) - valid_log_pf_s_exit = pf_dist.log_prob( - torch.full_like(actions.tensor, actions.__class__.exit_action[0].item()) - ) + # Avoid .item() in hot path to stay compile-friendly; broadcast exit_action tensor. + exit_action_tensor = actions.__class__.exit_action.to( + actions.tensor.device, dtype=actions.tensor.dtype + ).expand_as(actions.tensor) + valid_log_pf_s_exit = pf_dist.log_prob(exit_action_tensor) + + # Reuse the exit_action tensor and create the next-state distribution once; this + # avoids an additional forward or repeated log_prob calls. + if len(valid_next_states) > 0: + if conditions is not None: + with has_conditions_exception_handler("pf", self.pf): + pf_next_outputs = self.pf(valid_next_states, conditions[mask]) + else: + with no_conditions_exception_handler("pf", self.pf): + pf_next_outputs = self.pf(valid_next_states) - # The following two lines are slightly inefficient, given that most - # next_states are also states, for which we already did a forward pass. - if transitions.conditions is not None: - with has_conditions_exception_handler("pf", self.pf): - module_output = self.pf(valid_next_states, transitions.conditions[mask]) + pf_next_dist = self.pf.to_probability_distribution( + valid_next_states, pf_next_outputs + ) + valid_log_pf_s_prime_exit = pf_next_dist.log_prob(exit_action_tensor) else: - with no_conditions_exception_handler("pf", self.pf): - module_output = self.pf(valid_next_states) - - valid_log_pf_s_prime_exit = self.pf.to_probability_distribution( - valid_next_states, module_output - ).log_prob( - torch.full_like(actions.tensor, actions.__class__.exit_action[0].item()) - ) + # Should be rare because of the early return above; keep shape-friendly zero. + valid_log_pf_s_prime_exit = torch.zeros_like(valid_log_pf_s_exit) non_exit_actions = actions[~actions.is_exit] if self.pb is not None: - if transitions.conditions is not None: + if conditions is not None: with has_conditions_exception_handler("pb", self.pb): - module_output = self.pb( - valid_next_states, transitions.conditions[mask] - ) + module_output = self.pb(valid_next_states, conditions[mask]) else: with no_conditions_exception_handler("pb", self.pb): module_output = self.pb(valid_next_states) @@ -435,7 +487,7 @@ def get_scores( targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit scores = preds - targets - if torch.any(torch.isinf(scores)): + if self.debug and torch.any(torch.isinf(scores)): raise ValueError("scores contains inf") return scores @@ -460,9 +512,11 @@ def loss( on the reduction method. """ del env - warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(transitions, recalculate_all_logprobs) scores = self.get_scores( - transitions, recalculate_all_logprobs=recalculate_all_logprobs + transitions, + recalculate_all_logprobs=recalculate_all_logprobs, ) scores = scores**2 return loss_reduce(scores, reduction) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index c1a70d59..b07c2559 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -41,15 +41,18 @@ class FMGFlowNet(GFlowNet[StatesContainer[DiscreteStates]]): the default (non-recurrent) PolicyMixin interface. """ - def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): + def __init__( + self, logF: DiscretePolicyEstimator, alpha: float = 1.0, debug: bool = False + ): """Initializes a FMGFlowNet instance. Args: logF: A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating the log flow of the edges (states -> next_states). alpha: A scalar weight for the reward matching loss. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ - super().__init__() + super().__init__(debug=debug) assert isinstance( logF, PolicyMixin ), "logF must use the default PolicyMixin interface" @@ -123,54 +126,75 @@ def flow_matching_loss( assert len(states.batch_shape) == 1 assert not torch.any(states.is_initial_state) - incoming_log_flows = torch.full_like( - states.backward_masks, -float("inf"), dtype=torch.get_default_dtype() + incoming_log_flows = torch.full( + states.backward_masks.shape, + -float("inf"), + device=states.device, + dtype=torch.get_default_dtype(), ) - outgoing_log_flows = torch.full_like( - states.forward_masks, -float("inf"), dtype=torch.get_default_dtype() + outgoing_log_flows = torch.full( + states.forward_masks.shape, + -float("inf"), + device=states.device, + dtype=torch.get_default_dtype(), ) - # TODO: Need to vectorize this loop. - for action_idx in range(env.n_actions - 1): - valid_backward_mask = states.backward_masks[:, action_idx] - valid_forward_mask = states.forward_masks[:, action_idx] - valid_backward_states = states[valid_backward_mask] - valid_forward_states = states[valid_forward_mask] - - backward_actions = torch.full_like( - valid_backward_states.backward_masks[:, 0], action_idx, dtype=torch.long - ).unsqueeze(-1) - backward_actions = env.actions_from_tensor(backward_actions) - - valid_backward_states_parents = env._backward_step( - valid_backward_states, backward_actions - ) + # Vectorized over actions. + valid_backward = states.backward_masks + backward_indices = valid_backward.nonzero(as_tuple=False) + if backward_indices.numel() > 0: + backward_state_idx = backward_indices[:, 0] # time index + backward_action_idx = backward_indices[:, 1] # action index + backward_states = states[backward_state_idx] + backward_actions_tensor = backward_action_idx.view(-1, 1) + backward_actions = env.actions_from_tensor(backward_actions_tensor) + backward_parents = env._backward_step(backward_states, backward_actions) # type: ignore + + # calculate log flows of backward actions. + if conditions is not None: + backward_conditions = conditions[backward_state_idx] + with has_conditions_exception_handler("logF", self.logF): + backward_logF = ( + self.logF(backward_parents, backward_conditions) + .gather(1, backward_action_idx.view(-1, 1)) + .squeeze(1) + ) + else: + with no_conditions_exception_handler("logF", self.logF): + backward_logF = ( + self.logF(backward_parents) + .gather(1, backward_action_idx.view(-1, 1)) + .squeeze(1) + ) + + incoming_log_flows[backward_state_idx, backward_action_idx] = backward_logF + + # Vectorized over all non-exit forward actions. + valid_forward = states.forward_masks[:, :-1] + forward_indices = valid_forward.nonzero(as_tuple=False) + if forward_indices.numel() > 0: + forward_state_idx = forward_indices[:, 0] + forward_action_idx = forward_indices[:, 1] + forward_states = states[forward_state_idx] if conditions is not None: # Mask out only valid conditions elements. - valid_backward_conditions = conditions[valid_backward_mask] - valid_forward_conditions = conditions[valid_forward_mask] - + forward_conditions = conditions[forward_state_idx] with has_conditions_exception_handler("logF", self.logF): - incoming_log_flows[valid_backward_mask, action_idx] = self.logF( - valid_backward_states_parents, - valid_backward_conditions, - )[:, action_idx] - - outgoing_log_flows[valid_forward_mask, action_idx] = self.logF( - valid_forward_states, - valid_forward_conditions, - )[:, action_idx] - + forward_logF = ( + self.logF(forward_states, forward_conditions) + .gather(1, forward_action_idx.view(-1, 1)) + .squeeze(1) + ) else: with no_conditions_exception_handler("logF", self.logF): - incoming_log_flows[valid_backward_mask, action_idx] = self.logF( - valid_backward_states_parents, - )[:, action_idx] + forward_logF = ( + self.logF(forward_states) + .gather(1, forward_action_idx.view(-1, 1)) + .squeeze(1) + ) - outgoing_log_flows[valid_forward_mask, action_idx] = self.logF( - valid_forward_states, - )[:, action_idx] + outgoing_log_flows[forward_state_idx, forward_action_idx] = forward_logF # Now the exit action. valid_forward_mask = states.forward_masks[:, -1] @@ -254,9 +278,10 @@ def loss( The computed flow matching loss as a tensor. The shape depends on the reduction method. """ - assert isinstance(states_container.intermediary_states, DiscreteStates) - assert isinstance(states_container.terminating_states, DiscreteStates) - if recalculate_all_logprobs: + if self.debug: + assert isinstance(states_container.intermediary_states, DiscreteStates) + assert isinstance(states_container.terminating_states, DiscreteStates) + if recalculate_all_logprobs and self.debug: warnings.warn( "recalculate_all_logprobs is not used for FM. Ignoring the argument." ) diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 37296c61..c82321cd 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -83,6 +83,7 @@ def __init__( log_reward_clip_min: float = -float("inf"), forward_looking: bool = False, constant_pb: bool = False, + debug: bool = False, ): """Initializes a SubTBGFlowNet instance. @@ -100,10 +101,15 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ super().__init__( - pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min + pf, + pb, + constant_pb=constant_pb, + log_reward_clip_min=log_reward_clip_min, + debug=debug, ) assert any( isinstance(logF, cls) @@ -371,13 +377,14 @@ def get_scores( ).unsqueeze(-1) ) - flat_preds = preds[~flattening_mask] - if torch.any(torch.isnan(flat_preds)): - raise ValueError("NaN in preds") + if self.debug: + flat_preds = preds[~flattening_mask] + if torch.any(torch.isnan(flat_preds)): + raise ValueError("NaN in preds") - flat_targets = targets[~flattening_mask] - if torch.any(torch.isnan(flat_targets)): - raise ValueError("NaN in targets") + flat_targets = targets[~flattening_mask] + if torch.any(torch.isnan(flat_targets)): + raise ValueError("NaN in targets") flattening_masks.append(flattening_mask) scores.append(preds - targets) @@ -492,15 +499,17 @@ def get_geometric_within_contributions( Returns: The contributions tensor of shape (max_len * (max_len+1) / 2, n_trajectories). """ - L = self.lamda - max_len = trajectories.max_length t_idx = trajectories.terminating_idx + default_dtype = torch.get_default_dtype() + lam = torch.as_tensor(self.lamda, device=t_idx.device, dtype=default_dtype) + max_len = trajectories.max_length + + arange = torch.arange(max_len, device=t_idx.device, dtype=default_dtype) + pow_lam = lam.pow(arange) # The following tensor represents the weights given to each possible # sub-trajectory length. - contributions = (L ** torch.arange(max_len, device=t_idx.device).double()).to( - torch.get_default_dtype() - ) + contributions = pow_lam contributions = contributions.unsqueeze(-1).repeat(1, len(trajectories)) contributions = contributions.repeat_interleave( torch.arange(max_len, 0, -1, device=t_idx.device), @@ -508,15 +517,18 @@ def get_geometric_within_contributions( output_size=int(max_len * (max_len + 1) / 2), ) - # Now we need to divide each column by n + (n-1) lambda +...+ 1*lambda^{n-1} - # where n is the length of the trajectory corresponding to that column + # Now we need to divide each column by + # sum_{i=0}^{n-1} (n - i) * lambda^i + # where n is the length of the trajectory corresponding to that column. # We can do it the ugly way, or using the cool identity: # https://www.wolframalpha.com/input?i=sum%28%28n-i%29+*+lambda+%5Ei%2C+i%3D0..n%29 - per_trajectory_denom = ( - 1.0 - / (1 - L) ** 2 - * (L * (L ** t_idx.double() - 1) + (1 - L) * t_idx.double()) - ).to(torch.get_default_dtype()) + pow_cumsum = pow_lam.cumsum(0) + i_pow_cumsum = (arange * pow_lam).cumsum(0) + gather_idx = (t_idx.clamp(min=1) - 1).long() + sum_geom = pow_cumsum[gather_idx] + sum_i_geom = i_pow_cumsum[gather_idx] + t_idx_f = t_idx.to(default_dtype) + per_trajectory_denom = t_idx_f * sum_geom - sum_i_geom contributions = contributions / per_trajectory_denom / len(trajectories) return contributions @@ -535,12 +547,16 @@ def loss( trajectories: The batch of trajectories to compute the loss with. recalculate_all_logprobs: Whether to re-evaluate all logprobs. reduction: The reduction method to use ('mean', 'sum', or 'none'). + Note: for geometric-based sub-trajectory weighting, 'mean' is not + supported and is coerced to 'sum' (a warning is emitted when + debug=True). Returns: The computed sub-trajectory balance loss as a tensor. The shape depends on the reduction method. """ - warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) # Get all scores and masks from the trajectories. scores, flattening_masks = self.get_scores( env, trajectories, recalculate_all_logprobs=recalculate_all_logprobs @@ -576,7 +592,8 @@ def loss( weights = ratio * ( L ** torch.arange(max_len, device=per_length_losses.device) ) - assert (weights.sum() - 1.0).abs() < 1e-5, f"{weights.sum()}" + if self.debug: + assert (weights.sum() - 1.0).abs() < 1e-5, f"{weights.sum()}" return (per_length_losses * weights).sum() # TODO: we need to know what reductions are valid for each weighting method. @@ -593,17 +610,19 @@ def loss( raise ValueError(f"Unknown weighting method {self.weighting}") flat_contributions = contributions[~flattening_mask] - assert ( - flat_contributions.sum() - 1.0 - ).abs() < 1e-5, f"{flat_contributions.sum()}" + if self.debug: + assert ( + flat_contributions.sum() - 1.0 + ).abs() < 1e-5, f"{flat_contributions.sum()}" final_scores = flat_contributions * all_scores[~flattening_mask].pow(2) # TODO: default was sum, should we allow mean? if reduction == "mean": - warnings.warn( - "Mean reduction is not supported for SubTBGFlowNet with geometric " - "weighting, using sum instead." - ) + if self.debug: + warnings.warn( + "Mean reduction is not supported for SubTBGFlowNet with geometric " + "weighting, using sum instead." + ) reduction = "sum" return loss_reduce(final_scores, reduction) diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index 8e2d6e6b..3283875b 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -49,6 +49,7 @@ def __init__( init_logZ: float = 0.0, constant_pb: bool = False, log_reward_clip_min: float = -float("inf"), + debug: bool = False, ): """Initializes a TBGFlowNet instance. @@ -63,9 +64,14 @@ def __init__( is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. log_reward_clip_min: If finite, clips log rewards to this value. + debug: If True, keep runtime safety checks active; disable in compiled runs. """ super().__init__( - pf, pb, constant_pb=constant_pb, log_reward_clip_min=log_reward_clip_min + pf, + pb, + constant_pb=constant_pb, + log_reward_clip_min=log_reward_clip_min, + debug=debug, ) self.logZ = logZ or nn.Parameter(torch.tensor(init_logZ)) @@ -109,7 +115,8 @@ def loss( reduction method. """ del env # unused - warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) scores = self.get_scores( trajectories, recalculate_all_logprobs=recalculate_all_logprobs ) @@ -119,14 +126,14 @@ def loss( if trajectories.conditions is not None: with is_callable_exception_handler("logZ", self.logZ): assert isinstance(self.logZ, ScalarEstimator) - logZ = self.logZ(trajectories.conditions) + logZ = self.logZ(trajectories.conditions) # [N] or [..., 1] else: - logZ = self.logZ + logZ = self.logZ # [] - logZ = cast(torch.Tensor, logZ) - scores = (scores + logZ.squeeze()).pow(2) + logZ = cast(torch.Tensor, logZ).squeeze() # [] or [N] + scores = torch.square(scores + logZ) # [N] loss = loss_reduce(scores, reduction) - if torch.isnan(loss).any(): + if self.debug and torch.isnan(loss).any(): raise ValueError("loss is nan") return loss @@ -170,13 +177,15 @@ def loss( the reduction method. """ del env # unused - warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) + if self.debug: + warn_about_recalculating_logprobs(trajectories, recalculate_all_logprobs) scores = self.get_scores( trajectories, recalculate_all_logprobs=recalculate_all_logprobs ) - scores = (scores - scores.mean()).pow(2) + scores = scores.sub_(scores.mean()) # [N], in-place mean-centering. + scores = torch.square(scores) # [N] loss = loss_reduce(scores, reduction) - if torch.isnan(loss).any(): + if self.debug and torch.isnan(loss).any(): raise ValueError("loss is NaN.") return loss diff --git a/src/gfn/gym/bitSequence.py b/src/gfn/gym/bitSequence.py index ed0ecb95..fadcc9a5 100644 --- a/src/gfn/gym/bitSequence.py +++ b/src/gfn/gym/bitSequence.py @@ -449,10 +449,6 @@ def make_modes_set(self, seed) -> torch.Tensor: ValueError: If the number of requested modes exceeds the number of possible unique sequences. """ - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - if self.H is None: self.H = torch.tensor( [ @@ -472,11 +468,16 @@ def make_modes_set(self, seed) -> torch.Tensor: "Not enough unique sequences available for the set of modes." ) + g = torch.Generator(device=self.device) + g.manual_seed(seed) + unique_indices = set() unique_sequences = [] while len(unique_sequences) < self.n_modes: - candidate = tuple(torch.randint(0, self.H.shape[0], (m,)).tolist()) + candidate = tuple( + torch.randint(0, self.H.shape[0], (m,), generator=g).tolist() + ) if candidate not in unique_indices: unique_indices.add(candidate) unique_sequences.append(candidate) diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 1047dc98..f0db5b38 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -178,19 +178,19 @@ def get_trajectory_pfs( else: # Vectorized path. - log_pf_trajectories = torch.full_like( - trajectories.actions.tensor[..., 0], + # Allocate log_pf explicitly as floating point to avoid silent int casts. + log_pf_trajectories = torch.full( + trajectories.actions.tensor[..., 0].shape, fill_value=fill_value, dtype=torch.get_default_dtype(), + device=trajectories.states.device, ) if len(valid_states) == 0: return log_pf_trajectories # Build conditions per-step shape to align with valid_states - masked_cond = None cond = trajectories.conditions - if cond is not None: T = trajectories.states.tensor.shape[0] # If conditions already has time dim (T, N, ...), index directly. @@ -199,6 +199,8 @@ def get_trajectory_pfs( else: # Broadcast (N, ...) to (T, N, ...), then index. masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask] + else: + masked_cond = None # avoids building an expanded view when unused ctx_v = policy_pf.init_context( int(len(valid_states)), @@ -267,10 +269,12 @@ def get_trajectory_pbs( if trajectories.is_backward: raise ValueError("Backward trajectories are not supported") - log_pb_trajectories = torch.full_like( - trajectories.actions.tensor[..., 0], + # Allocate log_pb explicitly as floating point to avoid silent int casts. + log_pb_trajectories = torch.full( + trajectories.actions.tensor[..., 0].shape, fill_value=fill_value, - dtype=torch.get_default_dtype(), # Floating point dtype. + dtype=torch.get_default_dtype(), + device=trajectories.states.device, ) # Note the different mask for valid states and actions compared to the pf case. @@ -292,7 +296,6 @@ def get_trajectory_pbs( # Using all non-initial states, calculate the backward policy, and the logprobs # of those actions. - masked_cond = None cond = trajectories.conditions if cond is not None: T = trajectories.states.tensor.shape[0] @@ -300,12 +303,16 @@ def get_trajectory_pbs( masked_cond = cond[state_mask] else: masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask] + else: + masked_cond = None # skip broadcasting when no conditions supplied # There is no backward policy in this case. if pb is None: # If pb is None, we assume that the gflownet DAG is a tree, and therefore # the backward policy probability is always 1 (log probs are 0). - valid_log_pb_actions = torch.zeros_like(valid_actions.tensor) + valid_log_pb_actions = torch.zeros_like( + valid_actions.tensor, dtype=torch.get_default_dtype() + ) valid_log_pb_actions = valid_log_pb_actions.squeeze(-1) # no padding. log_pb_trajectories[action_mask] = valid_log_pb_actions.to( log_pb_trajectories.dtype, copy=False diff --git a/testing/test_gflownet.py b/testing/test_gflownet.py index 953d217f..135c446e 100644 --- a/testing/test_gflownet.py +++ b/testing/test_gflownet.py @@ -1,11 +1,21 @@ +import pytest +import torch + from gfn.containers import StatesContainer, Trajectories from gfn.containers.base import Container -from gfn.estimators import DiscretePolicyEstimator +from gfn.estimators import DiscretePolicyEstimator, ScalarEstimator from gfn.gflownet import FMGFlowNet, TBGFlowNet +from gfn.gflownet.base import loss_reduce +from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet from gfn.gym import Box, HyperGrid from gfn.gym.helpers.box_utils import BoxPBEstimator, BoxPBMLP, BoxPFEstimator, BoxPFMLP from gfn.preprocessors import KHotPreprocessor +from gfn.samplers import Sampler from gfn.states import DiscreteStates +from gfn.utils.handlers import ( + has_conditions_exception_handler, + no_conditions_exception_handler, +) from gfn.utils.modules import MLP @@ -86,3 +96,265 @@ def test_pytorch_inheritance(): assert hasattr( fmgflownet.state_dict(), "__dict__" ), "Expected gflownet to have indexable state_dict() method inherited from nn.Module" + + +@pytest.mark.parametrize("seed", [0, 12, 47, 67]) +def test_flow_matching_vectorized_matches_original(seed): + torch.manual_seed(seed) + env = HyperGrid(ndim=2) + preprocessor = KHotPreprocessor(ndim=env.ndim, height=env.height) + module = MLP(input_dim=preprocessor.output_dim, output_dim=env.n_actions) + estimator = DiscretePolicyEstimator( + module, n_actions=env.n_actions, preprocessor=preprocessor + ) + gflownet = FMGFlowNet(estimator) + + trajectories = gflownet.sample_trajectories(env, n=6) + states_container = gflownet.to_training_samples(trajectories) + states = states_container.intermediary_states + conditions = states_container.intermediary_conditions + + if len(states) == 0: + # If the sample produced only terminal states, resample with more trajectories. + trajectories = gflownet.sample_trajectories(env, n=12) + states_container = gflownet.to_training_samples(trajectories) + states = states_container.intermediary_states + conditions = states_container.intermediary_conditions + + assert len(states) > 0 + + def flow_matching_loss_original( + self, env, states, conditions, reduction: str = "mean" + ): + if len(states) == 0: + return torch.tensor(0.0, device=states.device) + assert len(states.batch_shape) == 1 + assert not torch.any(states.is_initial_state) + incoming_log_flows = torch.full_like( + states.backward_masks, -float("inf"), dtype=torch.get_default_dtype() + ) + outgoing_log_flows = torch.full_like( + states.forward_masks, -float("inf"), dtype=torch.get_default_dtype() + ) + for action_idx in range(env.n_actions - 1): + valid_backward_mask = states.backward_masks[:, action_idx] + valid_forward_mask = states.forward_masks[:, action_idx] + valid_backward_states = states[valid_backward_mask] + valid_forward_states = states[valid_forward_mask] + backward_actions = torch.full_like( + valid_backward_states.backward_masks[:, 0], action_idx, dtype=torch.long + ).unsqueeze(-1) + backward_actions = env.actions_from_tensor(backward_actions) + valid_backward_states_parents = env._backward_step( + valid_backward_states, backward_actions + ) + if conditions is not None: + valid_backward_conditions = conditions[valid_backward_mask] + valid_forward_conditions = conditions[valid_forward_mask] + with has_conditions_exception_handler("logF", self.logF): + incoming_log_flows[valid_backward_mask, action_idx] = self.logF( + valid_backward_states_parents, + valid_backward_conditions, + )[:, action_idx] + outgoing_log_flows[valid_forward_mask, action_idx] = self.logF( + valid_forward_states, + valid_forward_conditions, + )[:, action_idx] + else: + with no_conditions_exception_handler("logF", self.logF): + incoming_log_flows[valid_backward_mask, action_idx] = self.logF( + valid_backward_states_parents, + )[:, action_idx] + outgoing_log_flows[valid_forward_mask, action_idx] = self.logF( + valid_forward_states, + )[:, action_idx] + valid_forward_mask = states.forward_masks[:, -1] + if conditions is not None: + with has_conditions_exception_handler("logF", self.logF): + outgoing_log_flows[valid_forward_mask, -1] = self.logF( + states[valid_forward_mask], + conditions[valid_forward_mask], + )[:, -1] + else: + with no_conditions_exception_handler("logF", self.logF): + outgoing_log_flows[valid_forward_mask, -1] = self.logF( + states[valid_forward_mask], + )[:, -1] + log_incoming_flows = torch.logsumexp(incoming_log_flows, dim=-1) + log_outgoing_flows = torch.logsumexp(outgoing_log_flows, dim=-1) + scores = (log_incoming_flows - log_outgoing_flows).pow(2) + return loss_reduce(scores, reduction) + + loss_original = flow_matching_loss_original( + gflownet, env, states, conditions, reduction="mean" + ) + loss_vectorized = gflownet.flow_matching_loss( + env, states, conditions, reduction="mean" + ) + + torch.testing.assert_close(loss_vectorized, loss_original) + + +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_subtb_get_scores_vectorized_matches_original(seed: int): + torch.manual_seed(seed) + n_traj = 4 + + # Deterministic HyperGrid env and frozen estimators so real methods can run. + env = HyperGrid(ndim=2, height=3, device="cpu", debug=False) + preproc = KHotPreprocessor(height=env.height, ndim=env.ndim) + + # Tiny MLPs with random weights (frozen for determinism). + module_pf = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions) + module_pb = MLP(input_dim=preproc.output_dim, output_dim=env.n_actions - 1) + module_logF = MLP(input_dim=preproc.output_dim, output_dim=1) + for mod in (module_pf, module_pb, module_logF): + for p in mod.parameters(): + p.requires_grad_(False) + + pf = DiscretePolicyEstimator( + module=module_pf, + n_actions=env.n_actions, + preprocessor=preproc, + is_backward=False, + ) + pb = DiscretePolicyEstimator( + module=module_pb, n_actions=env.n_actions, preprocessor=preproc, is_backward=True + ) + logF = ScalarEstimator(module=module_logF, preprocessor=preproc) + + # Initialize model via __init__ to set up real methods. + model = SubTBGFlowNet( + pf=pf, pb=pb, logF=logF, weighting="geometric_within", lamda=0.9 + ) + model.debug = False + model.log_reward_clip_min = float("-inf") + model.eval() + pf.eval() + pb.eval() + logF.eval() + + # Sample a deterministic batch of trajectories with frozen estimators. + sampler = Sampler(estimator=pf) + trajectories: Trajectories = sampler.sample_trajectories( + env, + n=n_traj, + epsilon=0.0, + save_logprobs=True, + save_estimator_outputs=False, + ) + max_len = trajectories.max_length # noqa: F841 used implicitly by shapes + + def original_get_scores(self, env, trajectories, recalculate_all_logprobs=True): + log_pf_trajectories_, log_pb_trajectories_ = self.get_pfs_and_pbs( + trajectories, recalculate_all_logprobs=recalculate_all_logprobs + ) + + log_pf_trajectories_cum = self.cumulative_logprobs( + trajectories, log_pf_trajectories_ + ) + log_pb_trajectories_cum = self.cumulative_logprobs( + trajectories, log_pb_trajectories_ + ) + + log_state_flows_ = self.calculate_log_state_flows( + env, trajectories, log_pf_trajectories_ + ) + sink_states_mask_, is_terminal_mask_ = self.calculate_masks( + log_state_flows_, trajectories + ) + + flattening_masks_orig = [] + scores_orig = [] + for i in range(1, 1 + trajectories.max_length): + preds = self.calculate_preds(log_pf_trajectories_cum, log_state_flows_, i) + targets = self.calculate_targets( + trajectories, + preds, + log_pb_trajectories_cum, + log_state_flows_, + is_terminal_mask_, + sink_states_mask_, + i, + ) + + flattening_mask = trajectories.terminating_idx.lt( + torch.arange( + i, + trajectories.max_length + 1, + device=trajectories.terminating_idx.device, + ).unsqueeze(-1) + ) + + flat_preds = preds[~flattening_mask] + if self.debug and torch.any(torch.isnan(flat_preds)): + raise ValueError("NaN in preds") + + flat_targets = targets[~flattening_mask] + if self.debug and torch.any(torch.isnan(flat_targets)): + raise ValueError("NaN in targets") + + flattening_masks_orig.append(flattening_mask) + scores_orig.append(preds - targets) + + return scores_orig, flattening_masks_orig + + def normalize_scores_masks( + scores, masks, trajectories: Trajectories + ) -> tuple[torch.Tensor, torch.Tensor]: + """Convert list outputs to padded tensors; pass tensors through unchanged.""" + if isinstance(scores, torch.Tensor): + assert isinstance(masks, torch.Tensor) + return scores, masks + + assert isinstance(scores, (list, tuple)) + assert isinstance(masks, (list, tuple)) + + max_len = trajectories.max_length + n_traj = ( + trajectories.n_trajectories + if hasattr(trajectories, "n_trajectories") + else len(trajectories) + ) + device = trajectories.terminating_idx.device + dtype = scores[0].dtype + + scores_padded = torch.zeros( + (max_len, max_len, n_traj), dtype=dtype, device=device + ) + masks_padded = torch.ones( + (max_len, max_len, n_traj), dtype=torch.bool, device=device + ) + + for i, (s, m) in enumerate(zip(scores, masks), start=1): + seq_len = s.shape[0] + scores_padded[i - 1, :seq_len] = s + masks_padded[i - 1, :seq_len] = m + + return scores_padded, masks_padded + + # Recompute logprobs to ensure PF/PB are evaluated for both paths. + orig_scores_list, orig_masks_list = original_get_scores( + model, env, trajectories, recalculate_all_logprobs=True + ) + vec_scores, vec_masks = model.get_scores( + env, trajectories, recalculate_all_logprobs=True + ) # type: ignore + + vec_scores_t, vec_masks_t = normalize_scores_masks( + vec_scores, vec_masks, trajectories + ) + orig_scores_t, orig_masks_t = normalize_scores_masks( + orig_scores_list, orig_masks_list, trajectories + ) + + valid_mask = ~orig_masks_t + if not torch.allclose( + vec_scores_t[valid_mask], orig_scores_t[valid_mask], equal_nan=True + ): + max_diff = (vec_scores_t[valid_mask] - orig_scores_t[valid_mask]).abs().max() + raise AssertionError( + f"Score mismatch on valid positions; max_abs_diff={max_diff.item()}" + ) + + torch.testing.assert_close(vec_masks_t, orig_masks_t, equal_nan=True) diff --git a/testing/test_subtb_contributions.py b/testing/test_subtb_contributions.py new file mode 100644 index 00000000..de137c32 --- /dev/null +++ b/testing/test_subtb_contributions.py @@ -0,0 +1,109 @@ +from typing import cast + +import pytest +import torch + +from gfn.containers import Trajectories +from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet + + +class DummyTrajectories: + def __init__(self, terminating_idx: torch.Tensor): + self.terminating_idx = terminating_idx + self.max_length = int(torch.max(terminating_idx).item()) + + def __len__(self) -> int: + return self.terminating_idx.numel() + + +def _reference_contributions( + lam: float | torch.Tensor, + terminating_idx: torch.Tensor, + max_len: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + lam64 = torch.as_tensor(lam, dtype=torch.float64, device=device) + t_idx64 = terminating_idx.to(dtype=torch.float64) + + base = lam64.pow(torch.arange(max_len, device=device, dtype=torch.float64)) + base = base.unsqueeze(-1).repeat(1, len(terminating_idx)) + base = base.repeat_interleave( + torch.arange(max_len, 0, -1, device=device), + dim=0, + output_size=int(max_len * (max_len + 1) / 2), + ) + + denom = [] + for n in t_idx64.long().tolist(): + ar = torch.arange(n, device=device, dtype=torch.float64) + denom.append(((n - ar) * lam64.pow(ar)).sum()) + denom = torch.stack(denom) + + return (base / denom / len(terminating_idx)).to(dtype) + + +def test_geometric_within_contributions_matches_reference(): + terminating_idx = torch.tensor([1, 2, 3, 5], dtype=torch.int64) + trajectories = DummyTrajectories(terminating_idx) + + model = object.__new__(SubTBGFlowNet) + model.lamda = 0.9 + + result = model.get_geometric_within_contributions( + cast(Trajectories, trajectories) # type: ignore[arg-type] + ) + reference = _reference_contributions( + lam=model.lamda, + terminating_idx=terminating_idx, + max_len=trajectories.max_length, + dtype=result.dtype, + device=terminating_idx.device, + ) + + torch.testing.assert_close(result, reference, rtol=1e-6, atol=1e-6) + + +def test_geometric_within_contributions_near_one_is_stable(): + terminating_idx = torch.tensor([2, 4, 6], dtype=torch.int64) + trajectories = DummyTrajectories(terminating_idx) + + model_near_one = object.__new__(SubTBGFlowNet) + model_near_one.lamda = 1.0 - 1e-4 + + model_exact_one = object.__new__(SubTBGFlowNet) + model_exact_one.lamda = 1.0 + + result_near_one = model_near_one.get_geometric_within_contributions( + cast(Trajectories, trajectories) # type: ignore[arg-type] + ) + result_exact_one = model_exact_one.get_geometric_within_contributions( + cast(Trajectories, trajectories) # type: ignore[arg-type] + ) + + assert torch.isfinite(result_near_one).all() + assert torch.isfinite(result_exact_one).all() + torch.testing.assert_close(result_near_one, result_exact_one, rtol=5e-3, atol=5e-5) + + +@pytest.mark.parametrize("lam", [0.0, 0.3, 0.7, 0.9, 0.999, 1.0]) +def test_geometric_within_contributions_matches_bruteforce(lam: float): + torch.manual_seed(0) + terminating_idx = torch.randint(1, 7, size=(6,), dtype=torch.int64) + trajectories = DummyTrajectories(terminating_idx) + + model = object.__new__(SubTBGFlowNet) + model.lamda = lam + + result = model.get_geometric_within_contributions( + cast(Trajectories, trajectories) # type: ignore[arg-type] + ) + reference = _reference_contributions( + lam=model.lamda, + terminating_idx=terminating_idx, + max_len=trajectories.max_length, + dtype=result.dtype, + device=terminating_idx.device, + ) + + torch.testing.assert_close(result, reference, rtol=1e-6, atol=1e-6) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index cacd2637..e4bcbf55 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -211,7 +211,7 @@ class BoxArgs(CommonArgs): @dataclass class BitSequenceArgs(CommonArgs): - n_iterations: int = 5000 + n_iterations: int = 1000 word_size: int = 1 seq_size: int = 4 n_modes: int = 2 @@ -220,6 +220,7 @@ class BitSequenceArgs(CommonArgs): lr_Z: 1e-2 seed: int = 0 batch_size: int = 32 + deterministic_mode: bool = True @dataclass @@ -657,7 +658,9 @@ def test_bitsequence(seq_size: int, n_modes: int): args = BitSequenceArgs(seq_size=seq_size, n_modes=n_modes) final_l1_dist = train_bitsequence_main(args) assert final_l1_dist is not None - # print(f"[DEBUG] BitSequence seq_size={seq_size}, n_modes={n_modes}, l1={final_l1_dist}") + print( + f"[DEBUG] BitSequence seq_size={seq_size}, n_modes={n_modes}, l1={final_l1_dist}" + ) if seq_size == 4 and n_modes == 2: assert final_l1_dist <= 9e-5 if seq_size == 4 and n_modes == 4: diff --git a/tutorials/examples/train_bit_sequences.py b/tutorials/examples/train_bit_sequences.py index 12ad27eb..287ed82b 100644 --- a/tutorials/examples/train_bit_sequences.py +++ b/tutorials/examples/train_bit_sequences.py @@ -26,14 +26,21 @@ def estimated_dist(gflownet: PFBasedGFlowNet, env: BitSequence): def main(args): seed = args.seed if args.seed != 0 else DEFAULT_SEED - set_seed(seed) + set_seed(seed, deterministic_mode=args.deterministic_mode) device = torch.device( "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" ) H = torch.randint( 0, 2, (args.n_modes, args.seq_size), dtype=torch.long, device=device ) - env = BitSequence(args.word_size, args.seq_size, args.n_modes, H=H, debug=__debug__) + env = BitSequence( + args.word_size, + args.seq_size, + args.n_modes, + H=H, + seed=seed, + debug=__debug__, + ) if args.loss == "TB": pf = MLP(env.words_per_seq, env.n_actions) @@ -111,6 +118,12 @@ def main(args): parser.add_argument("--loss", type=str, default="TB", help="Loss to use") parser.add_argument("--no_cuda", type=bool, default=True, help="Device to use") parser.add_argument("--seed", type=int, default=0, help="Seed") + parser.add_argument( + "--deterministic_mode", + action="store_true", + default=False, + help="Use deterministic algorithms/threads where possible", + ) parser.add_argument( "--n_iterations", type=int, default=1000, help="Number of iterations" ) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 6f3db577..193e4cce 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -58,10 +58,18 @@ from gfn.utils.common import Timer, set_seed from gfn.utils.distributed import DistributedContext, initialize_distributed_compute from gfn.utils.modules import MLP, DiscreteUniform, Tabular -from tutorials.examples.multinode.spawn_policy import ( - AsyncSelectiveAveragingPolicy, - AverageAllPolicy, -) + +# Allow running both as a package module and as a standalone script +try: + from tutorials.examples.multinode.spawn_policy import ( + AsyncSelectiveAveragingPolicy, + AverageAllPolicy, + ) +except ImportError: + from multinode.spawn_policy import ( + AsyncSelectiveAveragingPolicy, + AverageAllPolicy, + ) class ModesReplayBufferManager(ReplayBufferManager): @@ -1326,13 +1334,13 @@ def cleanup(): parser.add_argument( "--wandb_project", type=str, - default="torchgfn", + default="", help="Name of the wandb project. If empty, don't use wandb", ) parser.add_argument( "--wandb_entity", type=str, - default="torchgfn", + default="", help="Name of the wandb entity. If empty, don't use wandb", ) parser.add_argument( diff --git a/tutorials/misc/bench_get_scores_all.py b/tutorials/misc/bench_get_scores_all.py new file mode 100644 index 00000000..950e195d --- /dev/null +++ b/tutorials/misc/bench_get_scores_all.py @@ -0,0 +1,1151 @@ +""" +Unified micro-benchmark runner for GFlowNet losses/get_scores. + +Runs, in order: +- Trajectory Balance (TB) loss +- Log Partition Variance (LPV) loss +- Sub-trajectory Balance (SubTB) get_scores +- Detailed Balance (DB) get_scores +- Modified Detailed Balance (ModDB) get_scores + +For each loss, the script reports four timings: + original (baseline, frozen copy) + original+compile (torch.compile applied to the baseline function) + current (eager) + current+compile (torch.compile applied to the current function) +Two speedups are printed: current/original and current+compile/original. + +All benchmarks use embedded base sizes, scaled by a single --size-scale +multiplier. Correctness is checked once per size before timing and is +excluded from the timing loops. +""" + +from __future__ import annotations + +import argparse +import math +from types import MethodType +from typing import Any, Callable, Iterable, Tuple + +import torch +import torch.nn as nn +from torch.utils import benchmark + +from gfn.gflownet.detailed_balance import DBGFlowNet, ModifiedDBGFlowNet +from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet +from gfn.gflownet.trajectory_balance import ( + LogPartitionVarianceGFlowNet, + TBGFlowNet, +) +from gfn.utils.handlers import ( + has_conditions_exception_handler, + no_conditions_exception_handler, +) + +# ------------------------- +# TB / LPV (loss benchmark) +# ------------------------- + + +class _TBTrajectoriesStub: + def __init__( + self, log_rewards: torch.Tensor, conditions: torch.Tensor | None = None + ): + self._log_rewards = log_rewards + self.n_trajectories = log_rewards.shape[0] + self.conditions = conditions + + @property + def log_rewards(self) -> torch.Tensor: + return self._log_rewards + + +def _build_tb( + T: int, N: int, device: torch.device, dtype: torch.dtype +) -> Tuple[TBGFlowNet, _TBTrajectoriesStub, torch.Tensor, torch.Tensor]: + log_pf = torch.randn(T, N, device=device, dtype=dtype) + log_pb = torch.randn(T, N, device=device, dtype=dtype) + log_rewards = torch.randn(N, device=device, dtype=dtype) + trajectories = _TBTrajectoriesStub(log_rewards) + + model = TBGFlowNet.__new__(TBGFlowNet) + nn.Module.__init__(model) + model.debug = False + model.log_reward_clip_min = -float("inf") + model.logZ = nn.Parameter(torch.tensor(0.0, device=device, dtype=dtype)) + + def _get_pfs_and_pbs(self, _trajectories, recalculate_all_logprobs: bool = True): + return log_pf, log_pb + + model.get_pfs_and_pbs = MethodType(_get_pfs_and_pbs, model) + return model, trajectories, log_pf, log_pb + + +def _build_lpv( + T: int, N: int, device: torch.device, dtype: torch.dtype +) -> Tuple[ + LogPartitionVarianceGFlowNet, _TBTrajectoriesStub, torch.Tensor, torch.Tensor +]: + log_pf = torch.randn(T, N, device=device, dtype=dtype) + log_pb = torch.randn(T, N, device=device, dtype=dtype) + log_rewards = torch.randn(N, device=device, dtype=dtype) + trajectories = _TBTrajectoriesStub(log_rewards) + + model = LogPartitionVarianceGFlowNet.__new__(LogPartitionVarianceGFlowNet) + nn.Module.__init__(model) + model.debug = False + model.log_reward_clip_min = -float("inf") + + def _get_pfs_and_pbs(self, _trajectories, recalculate_all_logprobs: bool = True): + return log_pf, log_pb + + model.get_pfs_and_pbs = MethodType(_get_pfs_and_pbs, model) + return model, trajectories, log_pf, log_pb + + +def _tb_original_get_scores( + model, trajectories, log_pf: torch.Tensor, log_pb: torch.Tensor +): + total_log_pf_trajectories = log_pf.sum(dim=0) + total_log_pb_trajectories = log_pb.sum(dim=0) + + log_rewards = trajectories.log_rewards + if math.isfinite(model.log_reward_clip_min): + log_rewards = log_rewards.clamp_min(model.log_reward_clip_min) + + return total_log_pf_trajectories - total_log_pb_trajectories - log_rewards + + +def _tb_original_loss(model, trajectories, log_pf: torch.Tensor, log_pb: torch.Tensor): + scores = _tb_original_get_scores(model, trajectories, log_pf, log_pb) + logZ = torch.as_tensor(model.logZ).squeeze() + scores = (scores + logZ).pow(2) + return scores.mean() + + +def _lpv_original_loss(model, trajectories, log_pf: torch.Tensor, log_pb: torch.Tensor): + scores = _tb_original_get_scores(model, trajectories, log_pf, log_pb) + centered = scores - scores.mean() + return centered.pow(2).mean() + + +# ------------------------- +# SubTB (get_scores) +# ------------------------- + + +class _SubTBDummyTrajectories: + def __init__(self, terminating_idx: torch.Tensor, max_length: int): + self.terminating_idx = terminating_idx + self.max_length = max_length + self.n_trajectories = terminating_idx.shape[0] + + def __len__(self) -> int: + return self.n_trajectories + + +def _subtb_build_model_and_data( + max_len: int, n_traj: int, seed: int = 0, device: str | torch.device = "cpu" +) -> Tuple[ + SubTBGFlowNet, _SubTBDummyTrajectories, list[torch.Tensor], list[torch.Tensor] +]: + torch.manual_seed(seed) + device = torch.device(device) + terminating_idx = torch.randint(1, max_len + 1, (n_traj,), device=device) + log_rewards = torch.randn(n_traj, device=device) + log_pf_trajectories = torch.randn(max_len, n_traj, device=device) + log_pb_trajectories = torch.randn(max_len, n_traj, device=device) + log_state_flows = torch.randn(max_len, n_traj, device=device) + sink_states_mask = torch.zeros(max_len, n_traj, dtype=torch.bool, device=device) + is_terminal_mask = torch.zeros(max_len, n_traj, dtype=torch.bool, device=device) + + preds_list = [ + torch.randn(max_len + 1 - i, n_traj, device=device) + for i in range(1, max_len + 1) + ] + targets_list = [ + torch.randn(max_len + 1 - i, n_traj, device=device) + for i in range(1, max_len + 1) + ] + + trajectories = _SubTBDummyTrajectories( + terminating_idx=terminating_idx, max_length=max_len + ) + + model = SubTBGFlowNet.__new__(SubTBGFlowNet) + torch.nn.Module.__init__(model) + model.debug = False + model.log_reward_clip_min = float("-inf") + + model.get_pfs_and_pbs = MethodType( + lambda self, traj, recalculate_all_logprobs=True: ( + log_pf_trajectories, + log_pb_trajectories, + ), + model, + ) + model.calculate_log_state_flows = MethodType( + lambda self, _env, _traj, _log_pf: log_state_flows, model + ) + model.calculate_masks = MethodType( + lambda self, _log_state_flows, _traj: (sink_states_mask, is_terminal_mask), + model, + ) + trajectories.log_rewards = log_rewards + model.calculate_preds = MethodType( + lambda self, _log_pf_cum, _log_state_flows, i: preds_list[i - 1], model + ) + model.calculate_targets = MethodType( + lambda self, _traj, _preds, _log_pb_cum, _log_state_flows, _term_mask, _sink_mask, i: targets_list[ + i - 1 + ], + model, + ) + + return model, trajectories, preds_list, targets_list + + +def _subtb_original_get_scores( + model: SubTBGFlowNet, env, trajectories +) -> Tuple[list[torch.Tensor], list[torch.Tensor]]: + log_pf_trajectories_, log_pb_trajectories_ = model.get_pfs_and_pbs( + trajectories, recalculate_all_logprobs=True + ) + + log_pf_trajectories_cum = model.cumulative_logprobs( + trajectories, log_pf_trajectories_ + ) + log_pb_trajectories_cum = model.cumulative_logprobs( + trajectories, log_pb_trajectories_ + ) + + log_state_flows_ = model.calculate_log_state_flows( + env, trajectories, log_pf_trajectories_ + ) + sink_states_mask_, is_terminal_mask_ = model.calculate_masks( + log_state_flows_, trajectories + ) + + flattening_masks_orig = [] + scores_orig = [] + for i in range(1, 1 + trajectories.max_length): + preds = model.calculate_preds(log_pf_trajectories_cum, log_state_flows_, i) + targets = model.calculate_targets( + trajectories, + preds, + log_pb_trajectories_cum, + log_state_flows_, + is_terminal_mask_, + sink_states_mask_, + i, + ) + + flattening_mask = trajectories.terminating_idx.lt( + torch.arange( + i, + trajectories.max_length + 1, + device=trajectories.terminating_idx.device, + ).unsqueeze(-1) + ) + + flat_preds = preds[~flattening_mask] + if model.debug and torch.any(torch.isnan(flat_preds)): + raise ValueError("NaN in preds") + + flat_targets = targets[~flattening_mask] + if model.debug and torch.any(torch.isnan(flat_targets)): + raise ValueError("NaN in targets") + + flattening_masks_orig.append(flattening_mask) + scores_orig.append(preds - targets) + + return scores_orig, flattening_masks_orig + + +# ------------------------- +# DB / ModDB (get_scores) +# ------------------------- + + +class _DBDummyStates: + def __init__(self, tensor: torch.Tensor, is_sink_state: torch.Tensor | None = None): + self.tensor = tensor + self.is_sink_state = ( + is_sink_state + if is_sink_state is not None + else torch.zeros(tensor.shape[0], dtype=torch.bool, device=tensor.device) + ) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_DBDummyStates": + return _DBDummyStates(self.tensor[idx], self.is_sink_state[idx]) + + @property + def batch_shape(self) -> torch.Size: + return self.tensor.shape[:-1] + + @property + def device(self) -> torch.device: + return self.tensor.device + + +class _DBDummyActions: + exit_action = torch.tensor(0, dtype=torch.long) + + def __init__(self, tensor: torch.Tensor): + self.tensor = tensor + self.is_exit = torch.zeros_like(tensor, dtype=torch.bool) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_DBDummyActions": + return _DBDummyActions(self.tensor[idx]) + + @property + def batch_shape(self) -> torch.Size: + return self.tensor.shape + + +class _DBDummyTransitions: + def __init__( + self, + states: _DBDummyStates, + next_states: _DBDummyStates, + actions: _DBDummyActions, + is_terminating: torch.Tensor, + log_rewards: torch.Tensor, + conditions: torch.Tensor | None = None, + ): + self.states = states + self.next_states = next_states + self.actions = actions + self.is_terminating = is_terminating + self.is_backward = False + self.conditions = conditions + self.log_rewards = log_rewards + self.device = states.device + self.n_transitions = len(states) + + def __len__(self) -> int: + return self.n_transitions + + +class _DBDummyEnv: + def __init__(self, log_reward_fn: Callable[[Any, Any | None], torch.Tensor]): + self._log_reward_fn = log_reward_fn + + def log_reward(self, states: Any, conditions: Any | None = None) -> torch.Tensor: + return self._log_reward_fn(states, conditions) + + +def _db_build_model_and_data( + n_transitions: int, + seed: int = 0, + device: str | torch.device = "cpu", + forward_looking: bool = False, +) -> Tuple[DBGFlowNet, _DBDummyEnv, _DBDummyTransitions]: + torch.manual_seed(seed) + device = torch.device(device) + + states_tensor = torch.randn(n_transitions, 4, device=device) + next_states_tensor = torch.randn(n_transitions, 4, device=device) + is_sink_state = torch.zeros(n_transitions, dtype=torch.bool, device=device) + states = _DBDummyStates(states_tensor, is_sink_state=is_sink_state) + next_states = _DBDummyStates(next_states_tensor, is_sink_state=is_sink_state.clone()) + + is_terminating = torch.zeros(n_transitions, dtype=torch.bool, device=device) + is_terminating[::3] = True + actions = _DBDummyActions( + torch.zeros(n_transitions, dtype=torch.long, device=device) + ) + log_rewards = torch.randn(n_transitions, device=device) + + transitions = _DBDummyTransitions( + states=states, + next_states=next_states, + actions=actions, + is_terminating=is_terminating, + log_rewards=log_rewards, + conditions=None, + ) + + log_pf = torch.randn(n_transitions, device=device) + log_pb = torch.randn(n_transitions, device=device) + logF_states = torch.randn(n_transitions, 1, device=device) + logF_next = torch.randn(n_transitions, 1, device=device) + log_reward_states = torch.randn(n_transitions, device=device) + log_reward_next = torch.randn(n_transitions, device=device) + + def get_pfs_and_pbs_stub(_self, _transitions, recalculate_all_logprobs: bool = True): + return log_pf, log_pb + + def logF_stub(_self, s, _conditions=None): + length = len(s) + if length == n_transitions: + return logF_states + return logF_next[:length] + + def log_reward_stub(_states, _conditions=None): + length = len(_states) + if length == n_transitions: + return log_reward_states + return log_reward_next[:length] + + env = _DBDummyEnv(log_reward_stub) + + model = DBGFlowNet.__new__(DBGFlowNet) + torch.nn.Module.__init__(model) + model.debug = False + model.forward_looking = forward_looking + model.log_reward_clip_min = -float("inf") + model.get_pfs_and_pbs = MethodType(get_pfs_and_pbs_stub, model) + model.logF = MethodType(logF_stub, model) + + return model, env, transitions + + +def _db_original_get_scores( + model: DBGFlowNet, + env: _DBDummyEnv, + transitions: _DBDummyTransitions, + recalculate_all_logprobs: bool = True, +) -> torch.Tensor: + if model.debug and transitions.is_backward: + raise ValueError("Backward transitions are not supported") + + states = transitions.states + transitions.actions + + if len(states) == 0: + return torch.tensor(0.0, device=transitions.device) + + log_pf, log_pb = model.get_pfs_and_pbs( + transitions, recalculate_all_logprobs=recalculate_all_logprobs + ) + + if transitions.conditions is not None: + with has_conditions_exception_handler("logF", model.logF): + log_F_s = model.logF(states, transitions.conditions).squeeze(-1) + else: + with no_conditions_exception_handler("logF", model.logF): + log_F_s = model.logF(states).squeeze(-1) + + log_F_s_next = torch.zeros_like(log_F_s) + is_terminating = transitions.is_terminating + is_intermediate = ~is_terminating + + interm_next_states = transitions.next_states[is_intermediate] + if transitions.conditions is not None: + with has_conditions_exception_handler("logF", model.logF): + log_F_s_next[is_intermediate] = model.logF( + interm_next_states, + transitions.conditions[is_intermediate], + ).squeeze(-1) + else: + with no_conditions_exception_handler("logF", model.logF): + log_F_s_next[is_intermediate] = model.logF(interm_next_states).squeeze(-1) + + if model.forward_looking: + if transitions.conditions is not None: + log_rewards_state = env.log_reward(states, transitions.conditions) + log_rewards_next = env.log_reward( + interm_next_states, transitions.conditions[is_intermediate] + ) + else: + log_rewards_state = env.log_reward(states) + log_rewards_next = env.log_reward(interm_next_states) + + log_rewards_state = log_rewards_state.clamp_min(model.log_reward_clip_min) + log_rewards_next = log_rewards_next.clamp_min(model.log_reward_clip_min) + + log_F_s = log_F_s + log_rewards_state + log_F_s_next[is_intermediate] = log_F_s_next[is_intermediate] + log_rewards_next + + log_rewards = transitions.log_rewards + log_rewards = log_rewards.clamp_min(model.log_reward_clip_min) + log_F_s_next[is_terminating] = log_rewards[is_terminating] + + preds = log_pf + log_F_s + targets = log_pb + log_F_s_next + scores = preds - targets + return scores + + +# ----- Modified DB ----- + + +class _ModDBDummyStates: + def __init__(self, tensor: torch.Tensor, is_sink_state: torch.Tensor | None = None): + self.tensor = tensor + self.is_sink_state = ( + is_sink_state + if is_sink_state is not None + else torch.zeros(tensor.shape[0], dtype=torch.bool, device=tensor.device) + ) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_ModDBDummyStates": + return _ModDBDummyStates(self.tensor[idx], self.is_sink_state[idx]) + + @property + def device(self) -> torch.device: + return self.tensor.device + + +class _ModDBDummyActions: + exit_action = torch.tensor(0, dtype=torch.long) + + def __init__(self, tensor: torch.Tensor, is_exit: torch.Tensor | None = None): + self.tensor = tensor + self.is_exit = ( + is_exit + if is_exit is not None + else torch.zeros_like(tensor, dtype=torch.bool) + ) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __getitem__(self, idx) -> "_ModDBDummyActions": + return _ModDBDummyActions(self.tensor[idx], self.is_exit[idx]) + + +class _ModDBDummyTransitions: + def __init__( + self, + states: _ModDBDummyStates, + next_states: _ModDBDummyStates, + actions: _ModDBDummyActions, + all_log_rewards: torch.Tensor, + is_backward: bool = False, + log_probs: torch.Tensor | None = None, + has_log_probs: bool = False, + conditions: torch.Tensor | None = None, + ): + self.states = states + self.next_states = next_states + self.actions = actions + self.all_log_rewards = all_log_rewards + self.is_backward = is_backward + self.log_probs = log_probs + self.has_log_probs = has_log_probs + self.conditions = conditions + self.device = states.device + self.n_transitions = len(states) + + def __len__(self) -> int: + return self.n_transitions + + def __getitem__(self, idx) -> "_ModDBDummyTransitions": + return _ModDBDummyTransitions( + self.states[idx], + self.next_states[idx], + self.actions[idx], + self.all_log_rewards[idx], + self.is_backward, + self.log_probs[idx] if self.log_probs is not None else None, + self.has_log_probs, + self.conditions[idx] if self.conditions is not None else None, + ) + + +class _ModDBFakeDist: + def __init__(self, log_action: torch.Tensor, log_exit: torch.Tensor): + self._log_action = log_action + self._log_exit = log_exit + + def log_prob(self, action_tensor: torch.Tensor) -> torch.Tensor: + n = action_tensor.shape[0] + if action_tensor.shape == self._log_exit.shape: + return self._log_exit + return self._log_action[:n] + + +class _ModDBDummyEstimator: + def __init__( + self, + log_action: torch.Tensor, + log_exit: torch.Tensor, + ): + self._log_action = log_action + self._log_exit = log_exit + + def __call__(self, states: _ModDBDummyStates, conditions=None): + return None + + def to_probability_distribution(self, states: _ModDBDummyStates, module_output=None): + return _ModDBFakeDist(self._log_action, self._log_exit) + + +def _moddb_build_model_and_data( + n_transitions: int, + seed: int = 0, + device: str | torch.device = "cpu", +) -> Tuple[ModifiedDBGFlowNet, _ModDBDummyTransitions]: + torch.manual_seed(seed) + device = torch.device(device) + + states_tensor = torch.randn(n_transitions, 4, device=device) + next_states_tensor = torch.randn(n_transitions, 4, device=device) + + is_sink_state = torch.zeros(n_transitions, dtype=torch.bool, device=device) + is_sink_state[::4] = True + states = _ModDBDummyStates( + states_tensor, is_sink_state=torch.zeros_like(is_sink_state) + ) + next_states = _ModDBDummyStates(next_states_tensor, is_sink_state=is_sink_state) + + actions_tensor = torch.randint(0, 5, (n_transitions,), device=device) + is_exit = torch.zeros_like(actions_tensor, dtype=torch.bool) + actions = _ModDBDummyActions(actions_tensor, is_exit=is_exit) + + all_log_rewards = torch.randn(n_transitions, 2, device=device) + + transitions = _ModDBDummyTransitions( + states=states, + next_states=next_states, + actions=actions, + all_log_rewards=all_log_rewards, + has_log_probs=False, + log_probs=None, + conditions=None, + ) + + non_sink_count = int((~is_sink_state).sum().item()) + log_pf_action = torch.randn(non_sink_count, device=device) + log_pf_exit = torch.randn(non_sink_count, device=device) + log_pf_exit_next = torch.randn(non_sink_count, device=device) + log_pb_action = torch.randn(non_sink_count, device=device) + + pf_estimator = _ModDBDummyEstimator(log_pf_action, log_pf_exit) + pb_estimator = _ModDBDummyEstimator(log_pb_action, log_pf_exit_next) + + model = ModifiedDBGFlowNet.__new__(ModifiedDBGFlowNet) + torch.nn.Module.__init__(model) + model.debug = False + model.constant_pb = False + model.pf = pf_estimator + model.pb = pb_estimator + model.log_reward_clip_min = -float("inf") + + return model, transitions + + +def _moddb_original_get_scores( + model: ModifiedDBGFlowNet, + transitions: _ModDBDummyTransitions, + recalculate_all_logprobs: bool = True, +) -> torch.Tensor: + if model.debug and transitions.is_backward: + raise ValueError("Backward transitions are not supported") + + if len(transitions) == 0: + return torch.tensor(0.0, device=transitions.device) + + mask = ~transitions.next_states.is_sink_state + states = transitions.states[mask] + valid_next_states = transitions.next_states[mask] + actions = transitions.actions[mask] + all_log_rewards = transitions.all_log_rewards[mask] + + if transitions.conditions is not None: + with has_conditions_exception_handler("pf", model.pf): + module_output = model.pf(states, transitions.conditions[mask]) + else: + with no_conditions_exception_handler("pf", model.pf): + module_output = model.pf(states) + + if len(states) == 0: + return torch.tensor(0.0, device=transitions.device) + + pf_dist = model.pf.to_probability_distribution(states, module_output) + + if transitions.has_log_probs and not recalculate_all_logprobs: + valid_log_pf_actions = transitions[mask].log_probs + assert valid_log_pf_actions is not None + else: + valid_log_pf_actions = pf_dist.log_prob(actions.tensor) + exit_action_tensor = actions.__class__.exit_action.to( + actions.tensor.device, dtype=actions.tensor.dtype + ).expand_as(actions.tensor) + valid_log_pf_s_exit = pf_dist.log_prob(exit_action_tensor) + + if transitions.conditions is not None: + with has_conditions_exception_handler("pf", model.pf): + module_output = model.pf(valid_next_states, transitions.conditions[mask]) + else: + with no_conditions_exception_handler("pf", model.pf): + module_output = model.pf(valid_next_states) + + valid_log_pf_s_prime_exit = model.pf.to_probability_distribution( + valid_next_states, module_output + ).log_prob(exit_action_tensor[: len(valid_next_states)]) + + non_exit_actions = actions[~actions.is_exit] + + if model.pb is not None: + if transitions.conditions is not None: + with has_conditions_exception_handler("pb", model.pb): + module_output = model.pb(valid_next_states, transitions.conditions[mask]) + else: + with no_conditions_exception_handler("pb", model.pb): + module_output = model.pb(valid_next_states) + + valid_log_pb_actions = model.pb.to_probability_distribution( + valid_next_states, module_output + ).log_prob(non_exit_actions.tensor) + else: + valid_log_pb_actions = torch.zeros_like(valid_log_pf_s_exit) + + preds = all_log_rewards[:, 0] + valid_log_pf_actions + valid_log_pf_s_prime_exit + targets = all_log_rewards[:, 1] + valid_log_pb_actions + valid_log_pf_s_exit + + scores = preds - targets + return scores + + +# ------------------------- +# Helpers +# ------------------------- + + +def _select_dtype(name: str, device: torch.device) -> torch.dtype: + if name == "fp32": + return torch.float32 + if name == "fp16": + if device.type != "cuda": + raise ValueError("fp16 is CUDA-only for this benchmark") + return torch.float16 + if name == "bf16": + return torch.bfloat16 + raise ValueError(f"Unsupported dtype {name}") + + +def _scale_int(value: int, scale: float) -> int: + return max(1, int(round(value * scale))) + + +def _scale_pair(pair: Tuple[int, int], scale: float) -> Tuple[int, int]: + return _scale_int(pair[0], scale), _scale_int(pair[1], scale) + + +def _time_fn(fn: Callable[[], Any]) -> float: + t = benchmark.Timer( + stmt="fn()", + globals={"fn": fn}, + num_threads=torch.get_num_threads(), + ).blocked_autorange(min_run_time=0.5) + return t.median + + +def _maybe_compile( + fn: Callable[..., Any] | None, enabled: bool +) -> Callable[..., Any] | None: + if fn is None or not enabled: + return None + return torch.compile(fn, fullgraph=False, dynamic=False, mode="reduce-overhead") + + +def _format_ms(value: float | None) -> str: + if value is None: + return " - " + return f"{value*1e3:10.3f}" + + +def _run_with_compile_variants( + eager_fn: Callable[[], Any], + compile_enabled: bool, +) -> tuple[float, float | None]: + t_eager = _time_fn(eager_fn) + compiled_fn = _maybe_compile(eager_fn, compile_enabled) + t_compiled = _time_fn(compiled_fn) if compiled_fn is not None else None + return t_eager, t_compiled + + +def _run_tb_or_lpv( + variant: str, + sizes: Iterable[int], + T: int, + device: torch.device, + dtype: torch.dtype, + repeat: int, + compile_enabled: bool, +): + print(f"\n=== {variant.upper()} loss ===") + print( + f"{'N':>10} {'chk':>4} {'orig(ms)':>10} {'orig+c(ms)':>10} {'curr(ms)':>10} {'curr+c(ms)':>10} {'spd':>6} {'spd_c':>6}" + ) + for size in sizes: + N = int(size) + # correctness (once, not timed) + if variant == "tb": + model, trajectories, log_pf, log_pb = _build_tb(T, N, device, dtype) + orig_val = _tb_original_loss(model, trajectories, log_pf, log_pb) + curr_val = model.loss(None, trajectories, recalculate_all_logprobs=False) # type: ignore[arg-type] + else: + model, trajectories, log_pf, log_pb = _build_lpv(T, N, device, dtype) + orig_val = _lpv_original_loss(model, trajectories, log_pf, log_pb) + curr_val = model.loss(None, trajectories, recalculate_all_logprobs=False) # type: ignore[arg-type] + diff = (orig_val - curr_val).abs().max().item() + tol = 1e-6 if dtype == torch.float32 else 5e-4 + status = "PASS" if diff <= tol else "FAIL" + + orig_times = [] + origc_times = [] + curr_times = [] + currc_times = [] + for _ in range(repeat): + t_orig, t_origc = _run_with_compile_variants( + eager_fn=lambda: ( + _tb_original_loss(model, trajectories, log_pf, log_pb) + if variant == "tb" + else _lpv_original_loss(model, trajectories, log_pf, log_pb) + ), + compile_enabled=compile_enabled, + ) + t_curr, t_currc = _run_with_compile_variants( + eager_fn=lambda: model.loss( + None, trajectories, recalculate_all_logprobs=False # type: ignore[arg-type] + ), + compile_enabled=compile_enabled, + ) + orig_times.append(t_orig) + curr_times.append(t_curr) + if t_origc is not None: + origc_times.append(t_origc) + if t_currc is not None: + currc_times.append(t_currc) + + t_orig_ms = torch.tensor(orig_times).median().item() * 1e3 + t_origc_ms = ( + torch.tensor(origc_times).median().item() * 1e3 if origc_times else None + ) + t_curr_ms = torch.tensor(curr_times).median().item() * 1e3 + t_currc_ms = ( + torch.tensor(currc_times).median().item() * 1e3 if currc_times else None + ) + speedup = t_orig_ms / t_curr_ms if t_curr_ms > 0 else float("inf") + speedup_c = ( + t_orig_ms / t_currc_ms if t_currc_ms and t_currc_ms > 0 else float("inf") + ) + print( + f"{N:10d} {status:>4} {_format_ms(t_orig_ms)} {_format_ms(t_origc_ms)} {_format_ms(t_curr_ms)} {_format_ms(t_currc_ms)} {speedup:6.2f} {speedup_c:6.2f}" + ) + + +def _run_subtb( + sizes: Iterable[Tuple[int, int]], + device: torch.device, + repeat: int, + compile_enabled: bool, +): + print("\n=== SubTB get_scores ===") + print( + f"{'size':>12} {'chk':>4} {'orig(ms)':>10} {'orig+c(ms)':>10} {'curr(ms)':>10} {'curr+c(ms)':>10} {'spd':>6} {'spd_c':>6}" + ) + for max_len, n_traj in sizes: + model, trajectories, _, _ = _subtb_build_model_and_data( + max_len, n_traj, device=device + ) + env_obj: Any = object() + + # correctness (once, not timed) + orig_scores, _ = _subtb_original_get_scores(model, env_obj, trajectories) + curr_scores, _ = model.get_scores(env_obj, trajectories) # type: ignore[arg-type] + max_abs = max( + (orig - curr).abs().max().item() + for orig, curr in zip(orig_scores, curr_scores) + ) + tol = 1e-6 + status = "PASS" if max_abs <= tol else "FAIL" + + orig_times = [] + origc_times = [] + curr_times = [] + currc_times = [] + for _ in range(repeat): + t_orig, t_origc = _run_with_compile_variants( + eager_fn=lambda: _subtb_original_get_scores( + model, env_obj, trajectories + ), + compile_enabled=compile_enabled, + ) + t_curr, t_currc = _run_with_compile_variants( + eager_fn=lambda: model.get_scores(env_obj, trajectories), # type: ignore[arg-type] + compile_enabled=compile_enabled, + ) + orig_times.append(t_orig) + curr_times.append(t_curr) + if t_origc is not None: + origc_times.append(t_origc) + if t_currc is not None: + currc_times.append(t_currc) + + t_orig_ms = torch.tensor(orig_times).median().item() * 1e3 + t_origc_ms = ( + torch.tensor(origc_times).median().item() * 1e3 if origc_times else None + ) + t_curr_ms = torch.tensor(curr_times).median().item() * 1e3 + t_currc_ms = ( + torch.tensor(currc_times).median().item() * 1e3 if currc_times else None + ) + speedup = t_orig_ms / t_curr_ms if t_curr_ms > 0 else float("inf") + speedup_c = ( + t_orig_ms / t_currc_ms if t_currc_ms and t_currc_ms > 0 else float("inf") + ) + print( + f"{max_len}x{n_traj:5d} {status:>4} {_format_ms(t_orig_ms)} {_format_ms(t_origc_ms)} {_format_ms(t_curr_ms)} {_format_ms(t_currc_ms)} {speedup:6.2f} {speedup_c:6.2f}" + ) + + +def _run_db( + sizes: Iterable[int], + device: torch.device, + repeat: int, + compile_enabled: bool, + forward_looking: bool, +): + print("\n=== DB get_scores ===") + print( + f"{'n_trans':>10} {'chk':>4} {'orig(ms)':>10} {'orig+c(ms)':>10} {'curr(ms)':>10} {'curr+c(ms)':>10} {'spd':>6} {'spd_c':>6}" + ) + for n_transitions in sizes: + model, env, transitions = _db_build_model_and_data( + n_transitions=n_transitions, + device=device, + forward_looking=forward_looking, + ) + + # correctness (once, not timed) + orig = _db_original_get_scores( + model, env, transitions, recalculate_all_logprobs=True + ) + curr = model.get_scores(env, transitions) # type: ignore[arg-type] + max_abs = (orig - curr).abs().max().item() + tol = 1e-6 + status = "PASS" if max_abs <= tol else "FAIL" + + orig_times = [] + origc_times = [] + curr_times = [] + currc_times = [] + for _ in range(repeat): + t_orig, t_origc = _run_with_compile_variants( + eager_fn=lambda: _db_original_get_scores( + model, env, transitions, recalculate_all_logprobs=True + ), + compile_enabled=compile_enabled, + ) + t_curr, t_currc = _run_with_compile_variants( + eager_fn=lambda: model.get_scores(env, transitions), # type: ignore[arg-type] + compile_enabled=compile_enabled, + ) + orig_times.append(t_orig) + curr_times.append(t_curr) + if t_origc is not None: + origc_times.append(t_origc) + if t_currc is not None: + currc_times.append(t_currc) + + t_orig_ms = torch.tensor(orig_times).median().item() * 1e3 + t_origc_ms = ( + torch.tensor(origc_times).median().item() * 1e3 if origc_times else None + ) + t_curr_ms = torch.tensor(curr_times).median().item() * 1e3 + t_currc_ms = ( + torch.tensor(currc_times).median().item() * 1e3 if currc_times else None + ) + speedup = t_orig_ms / t_curr_ms if t_curr_ms > 0 else float("inf") + speedup_c = ( + t_orig_ms / t_currc_ms if t_currc_ms and t_currc_ms > 0 else float("inf") + ) + print( + f"{n_transitions:10d} {status:>4} {_format_ms(t_orig_ms)} {_format_ms(t_origc_ms)} {_format_ms(t_curr_ms)} {_format_ms(t_currc_ms)} {speedup:6.2f} {speedup_c:6.2f}" + ) + + +def _run_moddb( + sizes: Iterable[int], + device: torch.device, + repeat: int, + compile_enabled: bool, +): + print("\n=== Modified DB get_scores ===") + print( + f"{'n_trans':>10} {'chk':>4} {'orig(ms)':>10} {'orig+c(ms)':>10} {'curr(ms)':>10} {'curr+c(ms)':>10} {'spd':>6} {'spd_c':>6}" + ) + for n_transitions in sizes: + model, transitions = _moddb_build_model_and_data( + n_transitions=n_transitions, + device=device, + ) + + # correctness (once, not timed) + orig = _moddb_original_get_scores( + model, transitions, recalculate_all_logprobs=True + ) + curr = model.get_scores(transitions) # type: ignore[arg-type] + max_abs = (orig - curr).abs().max().item() + tol = 1e-6 + status = "PASS" if max_abs <= tol else "FAIL" + + orig_times = [] + origc_times = [] + curr_times = [] + currc_times = [] + for _ in range(repeat): + t_orig, t_origc = _run_with_compile_variants( + eager_fn=lambda: _moddb_original_get_scores( + model, transitions, recalculate_all_logprobs=True + ), + compile_enabled=compile_enabled, + ) + t_curr, t_currc = _run_with_compile_variants( + eager_fn=lambda: model.get_scores(transitions), # type: ignore[arg-type] + compile_enabled=compile_enabled, + ) + orig_times.append(t_orig) + curr_times.append(t_curr) + if t_origc is not None: + origc_times.append(t_origc) + if t_currc is not None: + currc_times.append(t_currc) + + t_orig_ms = torch.tensor(orig_times).median().item() * 1e3 + t_origc_ms = ( + torch.tensor(origc_times).median().item() * 1e3 if origc_times else None + ) + t_curr_ms = torch.tensor(curr_times).median().item() * 1e3 + t_currc_ms = ( + torch.tensor(currc_times).median().item() * 1e3 if currc_times else None + ) + speedup = t_orig_ms / t_curr_ms if t_curr_ms > 0 else float("inf") + speedup_c = ( + t_orig_ms / t_currc_ms if t_currc_ms and t_currc_ms > 0 else float("inf") + ) + print( + f"{n_transitions:10d} {status:>4} {_format_ms(t_orig_ms)} {_format_ms(t_origc_ms)} {_format_ms(t_curr_ms)} {_format_ms(t_currc_ms)} {speedup:6.2f} {speedup_c:6.2f}" + ) + + +# ------------------------- +# Main +# ------------------------- + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--device", default="cpu", help="Device to run on (cpu, mps, cuda)." + ) + parser.add_argument( + "--dtype", + choices=["fp32", "fp16", "bf16"], + default="fp32", + help="Floating dtype for TB/LPV (others use fp32).", + ) + parser.add_argument( + "--size-scale", + type=float, + default=1.0, + help="Scale factor applied to all embedded base sizes.", + ) + parser.add_argument( + "--repeat", + type=int, + default=1, + help="Repeat each size multiple times; medians are reported.", + ) + parser.add_argument( + "--compile", + default=True, + action=argparse.BooleanOptionalAction, + help="Enable torch.compile for both original and current functions.", + ) + parser.add_argument( + "--forward-looking", + action="store_true", + help="Enable forward-looking reward path for DB benchmark.", + ) + return parser.parse_args() + + +def main(): + args = _parse_args() + device = torch.device(args.device) + if device.type == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but not available") + dtype = _select_dtype(args.dtype, device) + + # Base sizes + base_tb_sizes = [10240, 40960, 163840] + base_T = 64 + base_subtb_sizes = [(80, 640), (160, 1280), (320, 2560)] + base_db_sizes = [65536, 131072, 262144] + base_moddb_sizes = [65536, 131072, 262144] + + tb_sizes = [_scale_int(n, args.size_scale) for n in base_tb_sizes] + subtb_sizes = [_scale_pair(p, args.size_scale) for p in base_subtb_sizes] + db_sizes = [_scale_int(n, args.size_scale) for n in base_db_sizes] + moddb_sizes = [_scale_int(n, args.size_scale) for n in base_moddb_sizes] + + print("Benchmarking TB/LPV/SubTB/DB/ModDB in sequence") + print(f"torch version: {torch.__version__}") + print(f"device: {device}") + print(f"dtype (TB/LPV): {dtype}") + print(f"num threads: {torch.get_num_threads()}") + print(f"size-scale: {args.size_scale}") + print(f"compile: {args.compile}") + print(f"repeat: {args.repeat}") + print(f"forward-looking (DB): {args.forward_looking}") + print() + print( + "Columns: original, original+compile, current, current+compile, speedup vs original (eager and compiled)." + ) + + _run_tb_or_lpv( + variant="tb", + sizes=tb_sizes, + T=base_T, + device=device, + dtype=dtype, + repeat=args.repeat, + compile_enabled=args.compile, + ) + _run_tb_or_lpv( + variant="lpv", + sizes=tb_sizes, + T=base_T, + device=device, + dtype=dtype, + repeat=args.repeat, + compile_enabled=args.compile, + ) + _run_subtb( + sizes=subtb_sizes, + device=device, + repeat=args.repeat, + compile_enabled=args.compile, + ) + _run_db( + sizes=db_sizes, + device=device, + repeat=args.repeat, + compile_enabled=args.compile, + forward_looking=args.forward_looking, + ) + _run_moddb( + sizes=moddb_sizes, + device=device, + repeat=args.repeat, + compile_enabled=args.compile, + ) + + +if __name__ == "__main__": + main()