From b3e590893d22dbc4b9eec36cce07c48f110f03c1 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Thu, 25 Sep 2025 17:15:20 +0100 Subject: [PATCH 1/4] support backward_logprobs and backward_estimator_outputs in Trajectories --- src/gfn/containers/trajectories.py | 243 ++++++++++++++----- src/gfn/containers/transitions.py | 32 ++- src/gfn/samplers.py | 74 +++--- tutorials/examples/train_hypergrid_buffer.py | 2 +- 4 files changed, 250 insertions(+), 101 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 49a6980f..f69a5702 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -55,7 +55,9 @@ def __init__( is_backward: bool = False, log_rewards: torch.Tensor | None = None, log_probs: torch.Tensor | None = None, + backward_log_probs: torch.Tensor | None = None, estimator_outputs: torch.Tensor | None = None, + backward_estimator_outputs: torch.Tensor | None = None, ) -> None: """Initializes a Trajectories instance. @@ -74,8 +76,12 @@ def __init__( log rewards of the trajectories. If None, computed on the fly when needed. log_probs: Optional tensor of shape (max_length, n_trajectories) indicating the log probabilities of the trajectories' actions. + backward_log_probs: Optional tensor of shape (max_length, n_trajectories) indicating + the backward log probabilities of the trajectories' actions. estimator_outputs: Optional tensor of shape (max_length, n_trajectories, ...) containing outputs of a function approximator for each step. + backward_estimator_outputs: Optional tensor of shape (max_length, n_trajectories, ...) + containing outputs of a function approximator for each backward step. Note: When states and actions are not None, the Trajectories is initialized as an empty container that can be populated later with the `extend` method. @@ -97,7 +103,9 @@ def __init__( terminating_idx, log_rewards, log_probs, + backward_log_probs, estimator_outputs, + backward_estimator_outputs, ]: if tensor is not None: ensure_same_device(tensor.device, device) @@ -142,6 +150,12 @@ def __init__( and self.log_probs.is_floating_point() ) + self.backward_log_probs = backward_log_probs + assert self.backward_log_probs is None or ( + self.backward_log_probs.shape == self.actions.batch_shape + and self.backward_log_probs.is_floating_point() + ) + self.estimator_outputs = estimator_outputs assert self.estimator_outputs is None or ( self.estimator_outputs.shape[: len(self.states.batch_shape)] @@ -149,6 +163,13 @@ def __init__( and self.estimator_outputs.is_floating_point() ) + self.backward_estimator_outputs = backward_estimator_outputs + assert self.backward_estimator_outputs is None or ( + self.backward_estimator_outputs.shape[: len(self.states.batch_shape)] + == self.actions.batch_shape + and self.backward_estimator_outputs.is_floating_point() + ) + def __repr__(self) -> str: """Returns a string representation of the Trajectories container. @@ -223,11 +244,14 @@ def max_length(self) -> int: @property def terminating_states(self) -> States: - """The terminating states of the trajectories. + """The terminating states of the trajectories. If backward, the terminating states + are in 0-th position. Returns: The terminating states. 
""" + if self.is_backward: + return self.states[0, torch.arange(self.n_trajectories)] return self.states[self.terminating_idx - 1, torch.arange(self.n_trajectories)] @property @@ -265,34 +289,43 @@ def __getitem__( index = [index] terminating_idx = self.terminating_idx[index] new_max_length = terminating_idx.max().item() if len(terminating_idx) > 0 else 0 - states = self.states[:, index] + states = self.states[: 1 + new_max_length, index] conditioning = ( - self.conditioning[:, index] if self.conditioning is not None else None + self.conditioning[: 1 + new_max_length, index] + if self.conditioning is not None + else None ) - actions = self.actions[:, index] - states = states[: 1 + new_max_length] - actions = actions[:new_max_length] - if self.log_probs is not None: - log_probs = self.log_probs[:, index] - log_probs = log_probs[:new_max_length] - else: - log_probs = None + actions = self.actions[:new_max_length, index] log_rewards = self._log_rewards[index] if self._log_rewards is not None else None - if self.estimator_outputs is not None: - # TODO: Is there a safer way to index self.estimator_outputs for - # for n-dimensional estimator outputs? - # - # First we index along the first dimension of the estimator outputs. - # This can be thought of as the instance dimension, and is - # compatible with all supported indexing approaches (dim=1). - # All dims > 1 are not explicitly indexed unless the dimensionality - # of `index` matches all dimensions of `estimator_outputs` aside - # from the first (trajectory) dimension. - estimator_outputs = self.estimator_outputs[:, index] - # Next we index along the trajectory length (dim=0) - estimator_outputs = estimator_outputs[:new_max_length] - else: - estimator_outputs = None + log_probs = ( + self.log_probs[:new_max_length, index] + if self.log_probs is not None + else None + ) + backward_log_probs = ( + self.backward_log_probs[:new_max_length, index] + if self.backward_log_probs is not None + else None + ) + # TODO: Is there a safer way to index self.estimator_outputs for + # for n-dimensional estimator outputs? + # + # First we index along the first dimension of the estimator outputs. + # This can be thought of as the instance dimension, and is + # compatible with all supported indexing approaches (dim=1). + # All dims > 1 are not explicitly indexed unless the dimensionality + # of `index` matches all dimensions of `estimator_outputs` aside + # from the first (trajectory) dimension. + estimator_outputs = ( + self.estimator_outputs[:new_max_length, index] + if self.estimator_outputs is not None + else None + ) + backward_estimator_outputs = ( + self.backward_estimator_outputs[:new_max_length, index] + if self.backward_estimator_outputs is not None + else None + ) return Trajectories( env=self.env, @@ -303,7 +336,9 @@ def __getitem__( is_backward=self.is_backward, log_rewards=log_rewards, log_probs=log_probs, + backward_log_probs=backward_log_probs, estimator_outputs=estimator_outputs, + backward_estimator_outputs=backward_estimator_outputs, ) def extend(self, other: Trajectories) -> None: @@ -315,6 +350,11 @@ def extend(self, other: Trajectories) -> None: Args: Another Trajectories to append. """ + + assert ( + self.is_backward == other.is_backward + ), "Trajectories must be of the same direction." 
+ if self.conditioning is not None: # TODO: Support the case raise NotImplementedError( @@ -333,6 +373,10 @@ def extend(self, other: Trajectories) -> None: self.log_probs = torch.full( size=(0, 0), fill_value=0.0, device=self.device ) + if other.backward_log_probs is not None: + self.backward_log_probs = torch.full( + size=(0, 0), fill_value=0.0, device=self.device + ) if other.estimator_outputs is not None: self.estimator_outputs = torch.full( size=(0, 0, *other.estimator_outputs.shape[2:]), @@ -340,6 +384,13 @@ def extend(self, other: Trajectories) -> None: dtype=other.estimator_outputs.dtype, device=self.device, ) + if other.backward_estimator_outputs is not None: + self.backward_estimator_outputs = torch.full( + size=(0, 0, *other.backward_estimator_outputs.shape[2:]), + fill_value=0.0, + dtype=other.backward_estimator_outputs.dtype, + device=self.device, + ) # TODO: The replay buffer is storing `dones` - this wastes a lot of space. self.actions.extend(other.actions) @@ -367,6 +418,17 @@ def extend(self, other: Trajectories) -> None: else: self.log_probs = None + if self.backward_log_probs is not None and other.backward_log_probs is not None: + self.backward_log_probs, other.backward_log_probs = pad_dim0_if_needed( + self.backward_log_probs, other.backward_log_probs, 0.0 + ) + self.backward_log_probs = torch.cat( + (self.backward_log_probs, other.backward_log_probs), dim=1 + ) + assert self.backward_log_probs.shape == self.actions.batch_shape + else: + self.backward_log_probs = None + # Do the same for estimator_outputs, but padding with -float("inf") instead of 0.0 if self.estimator_outputs is not None and other.estimator_outputs is not None: self.estimator_outputs, other.estimator_outputs = pad_dim0_if_needed( @@ -382,6 +444,26 @@ def extend(self, other: Trajectories) -> None: else: self.estimator_outputs = None + if ( + self.backward_estimator_outputs is not None + and other.backward_estimator_outputs is not None + ): + self.backward_estimator_outputs, other.backward_estimator_outputs = ( + pad_dim0_if_needed( + self.backward_estimator_outputs, other.backward_estimator_outputs + ) + ) + self.backward_estimator_outputs = torch.cat( + (self.backward_estimator_outputs, other.backward_estimator_outputs), + dim=1, + ) + assert ( + self.backward_estimator_outputs.shape[: len(self.actions.batch_shape)] + == self.actions.batch_shape + ) + else: + self.backward_estimator_outputs = None + def to_transitions(self) -> Transitions: """Returns a Transitions object from the current Trajectories. @@ -427,6 +509,11 @@ def to_transitions(self) -> Transitions: if self.log_probs is not None else None ) + backward_log_probs = ( + self.backward_log_probs[~self.actions.is_dummy] + if self.backward_log_probs is not None + else None + ) return Transitions( env=self.env, @@ -438,6 +525,7 @@ def to_transitions(self) -> Transitions: is_backward=self.is_backward, log_rewards=log_rewards, log_probs=log_probs, + backward_log_probs=backward_log_probs, # FIXME: Add estimator_outputs. 
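+            # NOTE: estimator_outputs / backward_estimator_outputs are per-step
+            # tensors indexed like the actions, so forwarding them here would need
+            # the same `~self.actions.is_dummy` masking applied to log_probs above.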
) @@ -527,19 +615,18 @@ def reverse_backward_trajectories(self) -> Trajectories: seq_lengths = self.terminating_idx # shape (n_trajectories,) max_len = int(seq_lengths.max().item()) - # Get actions and states - actions = self.actions # shape (max_len, n_trajectories *action_dim) + # Get tensors states = self.states # shape (max_len + 1, n_trajectories, *state_dim) - - # Initialize new actions and states - new_actions = self.env.Actions.make_dummy_actions( - (max_len + 1, len(self)), device=actions.device - ) - # shape (max_len + 1, n_trajectories, *action_dim) - new_states = self.env.States.make_sink_states( - (max_len + 2, len(self)), device=states.device - ) - # shape (max_len + 2, n_trajectories, *state_dim) + actions = self.actions # shape (max_len, n_trajectories *action_dim) + log_probs = self.log_probs # shape (max_len, n_trajectories) + backward_log_probs = self.backward_log_probs # shape (max_len, n_trajectories) + estimator_outputs = ( + self.estimator_outputs + ) # shape (max_len, n_trajectories, ...) + backward_estimator_outputs = ( + self.backward_estimator_outputs + ) # shape (max_len, n_trajectories, ...) + device = states.device # device should be the same for all tensors # Create helper indices and masks idx = ( @@ -556,24 +643,21 @@ def reverse_backward_trajectories(self) -> Trajectories: # version that operates directly in (time, trajectory, *) space. # ------------------------------------------------------------- - # 1. Reverse actions --------------------------------------------------- # Gather linear indices where the mask is valid time_idx, traj_idx = torch.nonzero(mask, as_tuple=True) # 1-D tensors src_time_idx = rev_idx[mask] # Corresponding source time indices - # Assign reversed actions - new_actions[time_idx, traj_idx] = actions[src_time_idx, traj_idx] - # Insert EXIT action right after the last real action of every trajectory - new_actions[seq_lengths, torch.arange(len(self), device=seq_lengths.device)] = ( - self.env.Actions.make_exit_actions((1,), device=actions.device) - ) - # 2. Reverse states ---------------------------------------------------- + # 1. Reverse states ---------------------------------------------------- + # Initialize new states + new_states = self.env.States.make_sink_states( + (max_len + 2, len(self)), device=device + ) # shape (max_len + 2, n_trajectories, *state_dim) # The last state of the backward trajectories must be s0. assert torch.all(states[-1].is_initial_state), "Last state must be s0" # First state of the forward trajectories is s0 for every trajectory new_states[0] = self.env.States.make_initial_states( - (len(self),), device=states.device + (len(self),), device=device ) # Broadcast over the trajectory dimension # We do not want to copy the last state (s0) from the backward trajectory. @@ -581,12 +665,64 @@ def reverse_backward_trajectories(self) -> Trajectories: new_states_data = new_states[1:-1] # shape (max_len, n_trajectories, *state_dim) new_states_data[time_idx, traj_idx] = states_excl_last[src_time_idx, traj_idx] + # 2. 
Reverse actions --------------------------------------------------- + # Initialize new actions + new_actions = self.env.Actions.make_dummy_actions( + (max_len + 1, len(self)), device=device + ) # shape (max_len + 1, n_trajectories, *action_dim) + # Assign reversed actions + new_actions[time_idx, traj_idx] = actions[src_time_idx, traj_idx] + # Insert EXIT action right after the last real action of every trajectory + new_actions[seq_lengths, torch.arange(len(self), device=device)] = ( + self.env.Actions.make_exit_actions((1,), device=device) + ) + # --------------------------------------------------------------------- - # new_actions / new_states already have the correct shapes - # new_actions: (max_len + 1, n_trajectories, *action_dim) + # new_states / new_actions already have the correct shapes # new_states: (max_len + 2, n_trajectories, *state_dim) + # new_actions: (max_len + 1, n_trajectories, *action_dim) # --------------------------------------------------------------------- + # 3. Reverse the others ------------------------------------------------ + # Reversing the other tensors are basically the same as reversing the actions. + new_log_probs = None + if log_probs is not None: + new_log_probs = torch.full( + (max_len + 1, len(self)), fill_value=0.0, device=device + ) # shape (max_len + 1, n_trajectories, *action_dim) + new_log_probs[time_idx, traj_idx] = log_probs[src_time_idx, traj_idx] + + new_estimator_outputs = None + if estimator_outputs is not None: + new_estimator_outputs = torch.full( + (max_len + 1, len(self), *estimator_outputs.shape[2:]), + fill_value=0.0, + device=device, + ) # shape (max_len + 1, n_trajectories, *action_dim) + new_estimator_outputs[time_idx, traj_idx] = estimator_outputs[ + src_time_idx, traj_idx + ] + + new_backward_log_probs = None + if backward_log_probs is not None: + new_backward_log_probs = torch.full( + (max_len + 1, len(self)), fill_value=0.0, device=device + ) # shape (max_len + 1, n_trajectories, *action_dim) + new_backward_log_probs[time_idx, traj_idx] = backward_log_probs[ + src_time_idx, traj_idx + ] + + new_backward_estimator_outputs = None + if backward_estimator_outputs is not None: + new_backward_estimator_outputs = torch.full( + (max_len + 1, len(self), *backward_estimator_outputs.shape[2:]), + fill_value=0.0, + device=device, + ) # shape (max_len + 1, n_trajectories, *action_dim) + new_backward_estimator_outputs[time_idx, traj_idx] = ( + backward_estimator_outputs[src_time_idx, traj_idx] + ) + reversed_trajectories = Trajectories( env=self.env, states=new_states, @@ -595,11 +731,10 @@ def reverse_backward_trajectories(self) -> Trajectories: terminating_idx=self.terminating_idx + 1, is_backward=False, log_rewards=self.log_rewards, - log_probs=None, # We can't simply pass the trajectories.log_probs - # Since `log_probs` is assumed to be the forward log probabilities. - # FIXME: To resolve this, we can save log_pfs and log_pbs in the - # trajectories object. - estimator_outputs=None, # Same as `log_probs`. + log_probs=new_log_probs, + backward_log_probs=new_backward_log_probs, + estimator_outputs=new_estimator_outputs, + backward_estimator_outputs=new_backward_estimator_outputs, ) return reversed_trajectories diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index c5254356..e8aabf62 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -37,6 +37,8 @@ class Transitions(Container): rewards of the transitions. 
log_probs: (Optional) Tensor of shape (n_transitions,) containing the log probabilities of the actions. + backward_log_probs: (Optional) Tensor of shape (n_transitions,) containing the + backward log probabilities of the actions. """ def __init__( @@ -50,6 +52,7 @@ def __init__( is_backward: bool = False, log_rewards: torch.Tensor | None = None, log_probs: torch.Tensor | None = None, + backward_log_probs: torch.Tensor | None = None, ): """Initializes a Transitions instance. @@ -70,6 +73,8 @@ def __init__( rewards for the transitions. If None, computed on the fly when needed. log_probs: Optional tensor of shape (n_transitions,) containing the log probabilities of the actions. + backward_log_probs: Optional tensor of shape (n_transitions,) containing the + backward log probabilities of the actions. Note: @@ -127,6 +132,12 @@ def __init__( and self.log_probs.is_floating_point() ) + self.backward_log_probs = backward_log_probs + assert self.backward_log_probs is None or ( + self.backward_log_probs.shape == self.actions.batch_shape + and self.backward_log_probs.is_floating_point() + ) + @property def device(self) -> torch.device: """The device on which the transitions are stored. @@ -264,7 +275,11 @@ def __getitem__( next_states = self.next_states[index] log_rewards = self._log_rewards[index] if self._log_rewards is not None else None log_probs = self.log_probs[index] if self.log_probs is not None else None - + backward_log_probs = ( + self.backward_log_probs[index] + if self.backward_log_probs is not None + else None + ) return Transitions( env=self.env, states=states, @@ -275,6 +290,7 @@ def __getitem__( is_backward=self.is_backward, log_rewards=log_rewards, log_probs=log_probs, + backward_log_probs=backward_log_probs, ) def extend(self, other: Transitions) -> None: @@ -305,6 +321,12 @@ def extend(self, other: Transitions) -> None: fill_value=0.0, device=self.device, ) + if other.backward_log_probs is not None: + self.backward_log_probs = torch.full( + size=(0,), + fill_value=0.0, + device=self.device, + ) assert len(self.states.batch_shape) == len(other.states.batch_shape) == 1 @@ -326,3 +348,11 @@ def extend(self, other: Transitions) -> None: self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=0) else: self.log_probs = None + + # Concatenate backward_log_probs of the trajectories. + if self.backward_log_probs is not None and other.backward_log_probs is not None: + self.backward_log_probs = torch.cat( + (self.backward_log_probs, other.backward_log_probs), dim=0 + ) + else: + self.backward_log_probs = None diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index bb89e240..58b72a22 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -36,6 +36,7 @@ def __init__(self, estimator: Estimator) -> None: probability distributions. """ self.estimator = estimator + self.is_backward = estimator.is_backward def sample_actions( self, @@ -152,9 +153,9 @@ def sample_trajectories( For backward trajectories, the reward is computed at the initial state (s0) rather than the terminal state (sf). 
""" - if self.estimator.is_backward: - # [ASSUMPTION] When backward sampling, all provided states are the - # terminating states (can be passed to log_reward fn) + if self.is_backward: + # [IMPORTANT ASSUMPTION] When backward sampling, all provided states are the + # *terminating* states (can be passed to log_reward fn) assert ( states is not None ), "When backward sampling, `states` must be provided" @@ -176,20 +177,12 @@ def sample_trajectories( assert states.batch_shape == conditioning.shape[: len(states.batch_shape)] ensure_same_device(states.device, conditioning.device) - dones = ( - states.is_initial_state - if self.estimator.is_backward - else states.is_sink_state - ) + dones = states.is_initial_state if self.is_backward else states.is_sink_state # Define dummy actions to avoid errors when stacking empty lists. trajectories_states: List[States] = [states] - trajectories_actions: List[Actions] = [ - env.actions_from_batch_shape((n_trajectories,)) - ] - trajectories_logprobs: List[torch.Tensor] = [ - torch.full((n_trajectories,), fill_value=0, device=device) - ] + trajectories_actions: List[Actions] = [] + trajectories_logprobs: List[torch.Tensor] = [] trajectories_terminating_idx = torch.zeros( n_trajectories, dtype=torch.long, device=device ) @@ -229,16 +222,15 @@ def sample_trajectories( all_estimator_outputs.append(estimator_outputs_padded) actions[~dones] = valid_actions + trajectories_actions.append(actions) if save_logprobs: assert ( actions_log_probs is not None ), "actions_log_probs should not be None when save_logprobs is True" log_probs[~dones] = actions_log_probs + trajectories_logprobs.append(log_probs) - trajectories_actions.append(actions) - trajectories_logprobs.append(log_probs) - - if self.estimator.is_backward: + if self.is_backward: new_states = env._backward_step(states, actions) else: new_states = env._step(states, actions) @@ -264,7 +256,7 @@ def sample_trajectories( # to filter out the already done ones. new_dones = ( new_states.is_initial_state - if self.estimator.is_backward + if self.is_backward else new_states.is_sink_state ) & ~dones trajectories_terminating_idx[new_dones] = step @@ -275,29 +267,15 @@ def sample_trajectories( # Stack all states and actions stacked_states = env.States.stack(trajectories_states) - stacked_actions = env.Actions.stack(trajectories_actions)[ - 1: - ] # Drop dummy action + stacked_actions = env.Actions.stack(trajectories_actions) stacked_logprobs = ( - torch.stack(trajectories_logprobs, dim=0)[1:] # Drop dummy logprob - if save_logprobs - else None + torch.stack(trajectories_logprobs, dim=0) if save_logprobs else None ) - # TODO: use torch.nested.nested_tensor(dtype, device, requires_grad). stacked_estimator_outputs = ( torch.stack(all_estimator_outputs, dim=0) if save_estimator_outputs else None ) - # If there are no logprobs or estimator outputs, set them to None. - # TODO: This is a hack to avoid errors when no logprobs or estimator outputs are - # saved. This bug was introduced when I changed the dtypes library-wide -- why - # is this happening? 
- if stacked_logprobs is not None and len(stacked_logprobs) == 0: - stacked_logprobs = None - if stacked_estimator_outputs is not None and len(stacked_estimator_outputs) == 0: - stacked_estimator_outputs = None - # Broadcast conditioning tensor to match states batch shape if needed if conditioning is not None: # The states have batch shape (max_length, n_trajectories) @@ -322,10 +300,16 @@ def sample_trajectories( conditioning=conditioning, actions=stacked_actions, terminating_idx=trajectories_terminating_idx, - is_backward=self.estimator.is_backward, + is_backward=self.is_backward, log_rewards=None, # will be calculated later - log_probs=stacked_logprobs, - estimator_outputs=stacked_estimator_outputs, + log_probs=stacked_logprobs if not self.is_backward else None, + backward_log_probs=stacked_logprobs if self.is_backward else None, + estimator_outputs=( + stacked_estimator_outputs if not self.is_backward else None + ), + backward_estimator_outputs=( + stacked_estimator_outputs if self.is_backward else None + ), ) return trajectories @@ -454,7 +438,7 @@ def local_search( (n_prevs).view(1, -1, 1).expand(-1, -1, *trajectories.states.state_shape), ).squeeze(0) recon_trajectories = super().sample_trajectories( - env, + env=env, states=env.states_from_tensor(junction_states_tsr), conditioning=conditioning, save_estimator_outputs=save_estimator_outputs, @@ -591,12 +575,12 @@ def sample_trajectories( and the improved trajectories from local search. """ trajectories = super().sample_trajectories( - env, - n, - states, - conditioning, - save_estimator_outputs, - save_logprobs or use_metropolis_hastings, + env=env, + n=n, + states=states, + conditioning=conditioning, + save_estimator_outputs=save_estimator_outputs, + save_logprobs=save_logprobs or use_metropolis_hastings, **policy_kwargs, ) diff --git a/tutorials/examples/train_hypergrid_buffer.py b/tutorials/examples/train_hypergrid_buffer.py index 8ecb1070..ebf6b71e 100644 --- a/tutorials/examples/train_hypergrid_buffer.py +++ b/tutorials/examples/train_hypergrid_buffer.py @@ -129,7 +129,7 @@ def main(args): bwd_trajectories = backward_sampler.sample_trajectories( env, states=terminating_states, - save_logprobs=False, # TODO: enable this + save_logprobs=True, save_estimator_outputs=False, # TODO: log rewards, conditioning, ... ) From fe7766176deb35a2eb0cb585ec0704a8f945b668 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Thu, 25 Sep 2025 17:16:31 +0100 Subject: [PATCH 2/4] support reuse of precomputed rewards for backward sampling --- src/gfn/samplers.py | 8 ++++++++ tutorials/examples/train_hypergrid_buffer.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 58b72a22..384bab68 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -115,6 +115,7 @@ def sample_trajectories( env: Env, n: Optional[int] = None, states: Optional[States] = None, + log_rewards: Optional[torch.Tensor] = None, conditioning: Optional[torch.Tensor] = None, save_estimator_outputs: bool = False, save_logprobs: bool = False, @@ -135,6 +136,8 @@ def sample_trajectories( states: Initial states to start trajectories from. It should have batch_shape of length 1 (no trajectory dim). If `None`, `n` must be provided and we initialize `n` trajectories with the environment's initial state. + log_rewards: Optional tensor of log rewards for backward sampling. If None, + log rewards are computed on the fly. conditioning: Optional tensor of conditioning information for conditional policies. 
Must match the batch shape of states. save_estimator_outputs: If True, saves the estimator outputs for each @@ -161,6 +164,10 @@ def sample_trajectories( ), "When backward sampling, `states` must be provided" # assert states in env.terminating_states # This assert would be useful, # unfortunately, not every environment implements this. + + # Compute log rewards on the fly if not provided. + if log_rewards is None: + log_rewards = env.log_reward(states) else: if states is None: assert n is not None, "Either kwarg `states` or `n` must be specified" @@ -417,6 +424,7 @@ def local_search( prev_trajectories = self.backward_sampler.sample_trajectories( env, states=trajectories.terminating_states, + log_rewards=trajectories.log_rewards, conditioning=conditioning, save_estimator_outputs=save_estimator_outputs, save_logprobs=save_logprobs, diff --git a/tutorials/examples/train_hypergrid_buffer.py b/tutorials/examples/train_hypergrid_buffer.py index ebf6b71e..c23b49c6 100644 --- a/tutorials/examples/train_hypergrid_buffer.py +++ b/tutorials/examples/train_hypergrid_buffer.py @@ -129,12 +129,12 @@ def main(args): bwd_trajectories = backward_sampler.sample_trajectories( env, states=terminating_states, + log_rewards=terminating_states_container.log_rewards, save_logprobs=True, save_estimator_outputs=False, # TODO: log rewards, conditioning, ... ) buffer_trajectories = bwd_trajectories.reverse_backward_trajectories() - buffer_trajectories._log_rewards = terminating_states_container.log_rewards optimizer.zero_grad() loss = gflownet.loss(env, buffer_trajectories, recalculate_all_logprobs=True) From 05539e2c5b7b4bb324711eedb772ef916fade778 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Thu, 25 Sep 2025 17:17:21 +0100 Subject: [PATCH 3/4] remove is_backward flag from Transitions since we don't actually support it --- src/gfn/containers/trajectories.py | 4 +++- src/gfn/containers/transitions.py | 13 ------------- src/gfn/gflownet/detailed_balance.py | 6 ------ src/gfn/utils/prob_calculations.py | 3 --- 4 files changed, 3 insertions(+), 23 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index f69a5702..a677ed83 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -471,6 +471,9 @@ def to_transitions(self) -> Transitions: A Transitions object with the same states, actions, and log_rewards as the current Trajectories. """ + if self.is_backward: + return self.reverse_backward_trajectories().to_transitions() + if self.conditioning is not None: # The conditioning tensor has shape (max_length, n_trajectories, 1) # The actions have shape (max_length, n_trajectories) @@ -522,7 +525,6 @@ def to_transitions(self) -> Transitions: actions=actions, is_terminating=is_terminating, next_states=next_states, - is_backward=self.is_backward, log_rewards=log_rewards, log_probs=log_probs, backward_log_probs=backward_log_probs, diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index e8aabf62..e5c7db3c 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -29,10 +29,6 @@ class Transitions(Container): is_terminating: Boolean tensor of shape (n_transitions,) indicating whether the action is the exit action. next_states: States with batch_shape (n_transitions,). - is_backward: Whether the transitions are backward transitions. When not - is_backward, the `states` are the parents of the transitions and the - `next_states` are the children. 
When is_backward, the `states` are the - children of the transitions and the `next_states` are the parents. _log_rewards: (Optional) Tensor of shape (n_transitions,) containing the log rewards of the transitions. log_probs: (Optional) Tensor of shape (n_transitions,) containing the log @@ -49,7 +45,6 @@ def __init__( actions: Actions | None = None, is_terminating: torch.Tensor | None = None, next_states: States | None = None, - is_backward: bool = False, log_rewards: torch.Tensor | None = None, log_probs: torch.Tensor | None = None, backward_log_probs: torch.Tensor | None = None, @@ -68,7 +63,6 @@ def __init__( the action is the exit action. next_states: States with batch_shape (n_transitions,). If None, an empty States object is created. - is_backward: Whether the transitions are backward transitions. log_rewards: Optional tensor of shape (n_transitions,) containing the log rewards for the transitions. If None, computed on the fly when needed. log_probs: Optional tensor of shape (n_transitions,) containing the log @@ -82,7 +76,6 @@ def __init__( an empty container that can be populated later with the `extend` method. """ self.env = env - self.is_backward = is_backward # Assert that all tensors are on the same device as the environment. device = self.env.device @@ -205,9 +198,6 @@ def log_rewards(self) -> torch.Tensor | None: If not provided at initialization, log rewards are computed on demand for terminating transitions. """ - if self.is_backward: - return None - if self._log_rewards is None: self._log_rewards = torch.full( (self.n_transitions,), @@ -233,8 +223,6 @@ def all_log_rewards(self) -> torch.Tensor: Log rewards tensor of shape (n_transitions, 2) for the transitions. """ # TODO: reuse self._log_rewards if it exists. - if self.is_backward: - raise NotImplementedError("Not implemented for backward transitions") is_sink_state = self.next_states.is_sink_state log_rewards = torch.full( (self.n_transitions, 2), @@ -287,7 +275,6 @@ def __getitem__( actions=actions, is_terminating=is_terminating, next_states=next_states, - is_backward=self.is_backward, log_rewards=log_rewards, log_probs=log_probs, backward_log_probs=backward_log_probs, diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index efb53d1b..b77b01a7 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -168,9 +168,6 @@ def get_scores( A tensor of shape (n_transitions,) representing the scores for each transition. """ - if transitions.is_backward: - raise ValueError("Backward transitions are not supported") - states = transitions.states actions = transitions.actions @@ -328,9 +325,6 @@ def get_scores( Returns: A tensor of shape (n_transitions,) containing the scores for each transition. """ - if transitions.is_backward: - raise ValueError("Backward transitions are not supported") - if len(transitions) == 0: return torch.tensor(0.0, device=transitions.device) diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 2a28a861..fdcf42cf 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -259,9 +259,6 @@ def get_transition_pfs_and_pbs( Raises: ValueError: If backward transitions are provided. 
""" - if transitions.is_backward: - raise ValueError("Backward transitions are not supported") - log_pf_transitions = get_transition_pfs(pf, transitions, recalculate_all_logprobs) log_pb_transitions = get_transition_pbs(pb, transitions) From a12ce72d64cee026595e8408981070d6de93125d Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Thu, 25 Sep 2025 18:09:22 +0100 Subject: [PATCH 4/4] fix test --- src/gfn/samplers.py | 64 +++++++++++++++----- tutorials/examples/train_hypergrid_buffer.py | 2 +- 2 files changed, 50 insertions(+), 16 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 384bab68..84364a5d 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -38,6 +38,27 @@ def __init__(self, estimator: Estimator) -> None: self.estimator = estimator self.is_backward = estimator.is_backward + def get_estimator_output( + self, states: States, conditioning: torch.Tensor | None = None + ) -> torch.Tensor: + """Gets the estimator output for the given states and conditioning. + + Args: + states: The states to get the estimator output for. + conditioning: The conditioning to get the estimator output for. + + Returns: + The estimator output for the given states and conditioning. + """ + # TODO: Should estimators instead ignore None for the conditioning vector? + if conditioning is not None: + with has_conditioning_exception_handler("estimator", self.estimator): + estimator_output = self.estimator(states, conditioning) + else: + with no_conditioning_exception_handler("estimator", self.estimator): + estimator_output = self.estimator(states) + return estimator_output + def sample_actions( self, env: Env, @@ -78,13 +99,7 @@ def sample_actions( - Optional tensor of log probabilities (if save_logprobs=True) - Optional tensor of estimator outputs (if save_estimator_outputs=True) """ - # TODO: Should estimators instead ignore None for the conditioning vector? - if conditioning is not None: - with has_conditioning_exception_handler("estimator", self.estimator): - estimator_output = self.estimator(states, conditioning) - else: - with no_conditioning_exception_handler("estimator", self.estimator): - estimator_output = self.estimator(states) + estimator_output = self.get_estimator_output(states, conditioning) dist = self.estimator.to_probability_distribution( states, estimator_output, **policy_kwargs @@ -274,14 +289,33 @@ def sample_trajectories( # Stack all states and actions stacked_states = env.States.stack(trajectories_states) - stacked_actions = env.Actions.stack(trajectories_actions) - stacked_logprobs = ( - torch.stack(trajectories_logprobs, dim=0) if save_logprobs else None - ) - # TODO: use torch.nested.nested_tensor(dtype, device, requires_grad). - stacked_estimator_outputs = ( - torch.stack(all_estimator_outputs, dim=0) if save_estimator_outputs else None - ) + + if len(trajectories_actions) > 0: + stacked_actions = env.Actions.stack(trajectories_actions) + stacked_logprobs = ( + torch.stack(trajectories_logprobs, dim=0) if save_logprobs else None + ) + stacked_estimator_outputs = ( + torch.stack(all_estimator_outputs, dim=0) + if save_estimator_outputs + else None + ) + else: # len(trajectories_actions) == 0 + # This can happen when we sample forward (backward) and the given + # `states` are all sink (initial) states. + # In this case, we need to create a dummy tensor with the correct shape. 
+ stacked_actions = env.actions_from_batch_shape((0, n_trajectories)) + stacked_logprobs = torch.zeros((0, n_trajectories), device=device) + if save_estimator_outputs: + # We need to check the shape of the estimator outputs to create a dummy tensor with the correct shape. + dummy_estimator_output = self.get_estimator_output( + env.states_from_batch_shape((n_trajectories,)), conditioning + ) + stacked_estimator_outputs = torch.zeros( + (0, n_trajectories, *dummy_estimator_output.shape[1:]), device=device + ) + else: + stacked_estimator_outputs = None # Broadcast conditioning tensor to match states batch shape if needed if conditioning is not None: diff --git a/tutorials/examples/train_hypergrid_buffer.py b/tutorials/examples/train_hypergrid_buffer.py index c23b49c6..e9d0a212 100644 --- a/tutorials/examples/train_hypergrid_buffer.py +++ b/tutorials/examples/train_hypergrid_buffer.py @@ -137,7 +137,7 @@ def main(args): buffer_trajectories = bwd_trajectories.reverse_backward_trajectories() optimizer.zero_grad() - loss = gflownet.loss(env, buffer_trajectories, recalculate_all_logprobs=True) + loss = gflownet.loss(env, buffer_trajectories, recalculate_all_logprobs=False) loss.backward() optimizer.step()
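
Usage sketch: the snippet below strings the four commits together the way `train_hypergrid_buffer.py` now does. It is a minimal sketch assuming `env`, `pb_estimator`, `terminating_states`, `precomputed_log_rewards`, and `gflownet` already exist in the surrounding training script; none of these names are introduced by the patch.

```python
from gfn.samplers import Sampler

# Placeholders assumed from the surrounding script (see train_hypergrid_buffer.py):
#   env                     -- a gfn Env instance
#   pb_estimator            -- a backward policy estimator (estimator.is_backward is True)
#   terminating_states      -- a batch of terminating states to start backward rollouts from
#   precomputed_log_rewards -- their log rewards, shape (n_trajectories,)
#   gflownet                -- the GFlowNet whose loss is being optimized
backward_sampler = Sampler(estimator=pb_estimator)

# Commits 1-2: backward sampling can now save log probabilities and reuse
# precomputed log rewards; if `log_rewards` is omitted, env.log_reward(states)
# is called on the fly instead.
bwd_trajectories = backward_sampler.sample_trajectories(
    env,
    states=terminating_states,
    log_rewards=precomputed_log_rewards,
    save_logprobs=True,
    save_estimator_outputs=False,
)

# Backward containers populate the backward_* fields; the forward ones stay None.
assert bwd_trajectories.backward_log_probs is not None
assert bwd_trajectories.log_probs is None

# Commit 1: reversal now carries log_probs / backward_log_probs / estimator
# outputs into the forward-direction container instead of dropping them.
fwd_trajectories = bwd_trajectories.reverse_backward_trajectories()

# Commit 4: the tutorial trains on the reversed trajectories without forcing a
# recomputation of every log probability.
loss = gflownet.loss(env, fwd_trajectories, recalculate_all_logprobs=False)
loss.backward()
```

A related design note: backward Trajectories are now converted to Transitions by reversing them first (commit 3), which is why the `is_backward` flag could be removed from the Transitions container.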