Merged · Changes from all commits (24 commits)
4 changes: 1 addition & 3 deletions docs/source/guides/creating_environments.md
@@ -89,9 +89,7 @@ the following method:
contains state-dependent forward and backward masks, which define allowable
forward and backward actions conditioned on the state. Note that in
calculating these masks, the user can leverage the helper methods
`DiscreteStates.set_nonexit_action_masks`,
`DiscreteStates.set_exit_masks`, and
`DiscreteStates.init_forward_masks`.
`DiscreteStates.set_nonexit_action_masks`, and `DiscreteStates.set_exit_masks`.

The code automatically implements the following two class factories, which the
majority of users will not need to overwrite. However, the user could override
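As a quick illustration of the docs paragraph above, here is a minimal sketch of an `update_masks` implementation that leans on the two remaining helpers. The environment (a counter that forbids repeating any non-exit action more than `MAX_REPEATS` times), the import path, and the assumed signatures (a boolean tensor over non-exit actions plus an `allow_exit` flag for `set_nonexit_action_masks`, a boolean batch index for `set_exit_masks`) are illustrative assumptions, not taken from this PR; check the `DiscreteStates` source for the exact API.

```python
from gfn.states import DiscreteStates

MAX_REPEATS = 5  # hypothetical cap on how often a non-exit action may repeat


# In practice this would be the `update_masks` method of a custom DiscreteEnv
# subclass; it is shown as a bare function here to keep the sketch short.
def update_masks(self, states: DiscreteStates) -> None:
    # Assumption: states.tensor counts, per non-exit action, how often that
    # action was taken, so its trailing dimension equals n_actions - 1.
    exhausted = states.tensor >= MAX_REPEATS  # bool, (*batch_shape, n_actions - 1)

    # Forbid exhausted non-exit actions; keep the exit action available
    # everywhere (allow_exit=True).
    states.set_nonexit_action_masks(exhausted, allow_exit=True)

    # Where every non-exit action is exhausted, force the exit action.
    states.set_exit_masks(exhausted.all(dim=-1))

    # Backward masks: an action can only be undone if it was taken at least once.
    states.backward_masks = states.tensor > 0
```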
4 changes: 1 addition & 3 deletions src/gfn/containers/replay_buffer.py
@@ -375,7 +375,7 @@ def add(self, training_container: ContainerUnion):
training_container = training_container[idx_bigger_rewards]

# TODO: Concatenate input with final state for conditional GFN.
if training_container.conditions:
if training_container.states.conditions is not None:
raise NotImplementedError(
"{instance.__class__.__name__} does not yet support conditional GFNs."
)
@@ -461,13 +461,11 @@ def add(self, training_container: ContainerUnion):
raise TypeError("Must be a StatesContainer")

terminating_states = training_container.terminating_states
conditions = training_container.conditions
log_rewards = training_container.log_rewards

terminating_states_container = StatesContainer(
env=self.env,
states=terminating_states,
conditions=conditions,
is_terminating=torch.ones(
len(terminating_states), dtype=torch.bool, device=self.env.device
),
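The two hunks above track the PR's central change: condition tensors now travel with the `States` object instead of the container, so the buffer reads `training_container.states.conditions` and no longer forwards a `conditions=` argument to `StatesContainer`. Below is a minimal sketch of that guard as a free-standing helper; `_assert_unconditional` is a hypothetical name used only for illustration.

```python
def _assert_unconditional(container) -> None:
    # Conditions live on the States object, so an unconditional container is
    # simply one whose states carry no conditions tensor.
    if container.states.conditions is not None:
        raise NotImplementedError(
            f"{container.__class__.__name__} does not yet support conditional GFNs."
        )
```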
42 changes: 1 addition & 41 deletions src/gfn/containers/states_container.py
@@ -23,8 +23,6 @@ class StatesContainer(Container, Generic[StateType]):
Attributes:
env: The environment where the states are defined.
states: States with batch_shape (n_states,).
conditions: (Optional) Tensor of shape (n_states,) containing the conditions
for the states.
is_terminating: Boolean tensor of shape (n_states,) indicating which states
are terminating.
_log_rewards: (Optional) Tensor of shape (n_states,) containing the log rewards
@@ -35,7 +33,6 @@ def __init__(
self,
env: Env,
states: StateType | None = None,
conditions: torch.Tensor | None = None,
is_terminating: torch.Tensor | None = None,
log_rewards: torch.Tensor | None = None,
):
@@ -45,8 +42,6 @@ def __init__(
env: The environment where the states are defined.
states: States with batch_shape (n_states,). If None, an empty batch is
created.
conditions: Optional tensor of shape (n_states,) containing the conditions
for the states.
is_terminating: Boolean tensor of shape (n_states,) indicating which states
are terminating. If None, all are set to False.
log_rewards: Optional tensor of shape (n_states,) containing the log rewards
@@ -58,7 +53,7 @@ def __init__(
device = self.env.device
if states is not None:
ensure_same_device(states.device, device)
for tensor in [is_terminating, conditions, log_rewards]:
for tensor in [is_terminating, log_rewards]:
ensure_same_device(tensor.device, device) if tensor is not None else True

self.states = (
@@ -69,11 +64,6 @@ def __init__(
assert len(self.states.batch_shape) == 1
batch_shape = self.states.batch_shape

self.conditions = conditions
assert self.conditions is None or (
self.conditions.shape[: len(batch_shape)] == batch_shape
)

self.is_terminating = (
is_terminating
if is_terminating is not None
@@ -119,28 +109,6 @@ def terminating_states(self) -> StateType:
"""
return cast(StateType, self.states[self.is_terminating])

@property
def intermediary_conditions(self) -> torch.Tensor | None:
"""Conditions for intermediary states.

Returns:
The conditions tensor for intermediary states, or None if not set.
"""
if self.conditions is None:
return None
return self.conditions[~self.states.is_initial_state]

@property
def terminating_conditions(self) -> torch.Tensor | None:
"""Conditions for terminating states.

Returns:
The conditions tensor for terminating states, or None if not set.
"""
if self.conditions is None:
return None
return self.conditions[self.is_terminating]

def __len__(self) -> int:
"""Returns the number of states in the container.

@@ -219,12 +187,6 @@ def extend(self, other: StatesContainer[StateType]) -> None:
(self.is_terminating, other.is_terminating), dim=0
)

# Concatenate conditions tensors if they exist.
if self.conditions is not None and other.conditions is not None:
self.conditions = torch.cat((self.conditions, other.conditions), dim=0)
else:
self.conditions = None

# Concatenate log_rewards of the trajectories if they exist.
if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat((self._log_rewards, other._log_rewards), dim=0)
@@ -248,14 +210,12 @@ def __getitem__(
# Cast the indexed states to maintain their type
states = cast(StateType, self.states[index])
is_terminating = self.is_terminating[index]
conditions = self.conditions[index] if self.conditions is not None else None
log_rewards = self._log_rewards[index] if self._log_rewards is not None else None

# We can construct a new StatesContainer with the same StateType
return StatesContainer[StateType](
env=self.env,
states=states,
conditions=conditions,
is_terminating=is_terminating,
log_rewards=log_rewards,
)
84 changes: 16 additions & 68 deletions src/gfn/containers/trajectories.py
@@ -29,8 +29,6 @@ class Trajectories(Container):
Attributes:
env: The environment where the states and actions are defined.
states: States with batch_shape (max_length+1, n_trajectories).
conditions: (Optional) Tensor of shape (n_trajectories, condition_vector_dim)
containing the condition vectors for the trajectories.
actions: Actions with batch_shape (max_length, n_trajectories).
terminating_idx: Tensor of shape (n_trajectories,) indicating the time step
at which each trajectory ends.
@@ -49,7 +47,6 @@ def __init__(
self,
env: Env,
states: States | None = None,
conditions: torch.Tensor | None = None,
actions: Actions | None = None,
terminating_idx: torch.Tensor | None = None,
is_backward: bool = False,
@@ -63,8 +60,6 @@ def __init__(
env: The environment where the states and actions are defined.
states: States with batch_shape (max_length+1, n_trajectories). If None,
an empty States object is created.
conditions: Optional tensor of shape (n_trajectories, condition_vector_dim)
containing the condition vectors for each trajectory.
actions: Actions with batch_shape (max_length, n_trajectories). If None,
an empty Actions object is created.
terminating_idx: Tensor of shape (n_trajectories,) indicating the time step
@@ -93,7 +88,6 @@ def __init__(
ensure_same_device(obj.device, device)

for tensor in [
conditions,
terminating_idx,
log_rewards,
log_probs,
@@ -107,12 +101,6 @@ def __init__(
)
assert len(self.states.batch_shape) == 2

self.conditions = conditions
assert self.conditions is None or (
self.conditions.shape[: len(self.states.batch_shape)]
== self.states.batch_shape
)

self.actions = (
actions if actions is not None else env.actions_from_batch_shape((0, 0))
)
@@ -266,7 +254,6 @@ def __getitem__(
terminating_idx = self.terminating_idx[index]
new_max_length = terminating_idx.max().item() if len(terminating_idx) > 0 else 0
states = self.states[:, index]
conditions = self.conditions[:, index] if self.conditions is not None else None
actions = self.actions[:, index]
states = states[: 1 + new_max_length]
actions = actions[:new_max_length]
@@ -295,7 +282,6 @@ def __getitem__(
return Trajectories(
env=self.env,
states=states,
conditions=conditions,
actions=actions,
terminating_idx=terminating_idx,
is_backward=self.is_backward,
@@ -311,14 +297,8 @@ def extend(self, other: Trajectories) -> None:
log_rewards).

Args:
Another Trajectories to append.
other: Another Trajectories to append.
"""
if self.conditions is not None:
# TODO: Support the case
raise NotImplementedError(
"`extend` is not implemented for conditional Trajectories."
)

if len(other) == 0:
return

@@ -387,18 +367,11 @@ def to_transitions(self) -> Transitions:
A Transitions object with the same states, actions, and log_rewards as the
current Trajectories.
"""
if self.conditions is not None:
# The conditions tensor has shape (max_length, n_trajectories, 1)
# The actions have shape (max_length, n_trajectories)
# We need to index the conditions tensor to match the actions
# The actions exclude the last step, so we need to exclude the last step from conditions
conditions = self.conditions[:-1][~self.actions.is_dummy]
else:
conditions = None
valid_action_mask = ~self.actions.is_dummy

states = self.states[:-1][~self.actions.is_dummy]
next_states = self.states[1:][~self.actions.is_dummy]
actions = self.actions[~self.actions.is_dummy]
states = self.states[:-1][valid_action_mask]
next_states = self.states[1:][valid_action_mask]
actions = self.actions[valid_action_mask]
is_terminating = (
next_states.is_sink_state
if not self.is_backward
@@ -429,7 +402,6 @@ def to_transitions(self) -> Transitions:
return Transitions(
env=self.env,
states=states,
conditions=conditions,
actions=actions,
is_terminating=is_terminating,
next_states=next_states,
@@ -454,56 +426,27 @@ def to_states_container(self) -> StatesContainer:
)
is_terminating[self.terminating_idx - 1, torch.arange(len(self))] = True

states = self.states.flatten()
is_terminating = is_terminating.flatten()

states = self.states
is_valid = ~states.is_sink_state & (
~states.is_initial_state | (states.is_initial_state & is_terminating)
)
states = states[is_valid]
is_terminating = is_terminating[is_valid]

conditions = None
if self.conditions is not None:
# The conditions tensor has shape (max_length, n_trajectories, 1)
# We need to flatten it to match the flattened states
# First, we need to repeat it to match the flattened shape
# The flattened states have shape (max_length * n_trajectories,)
# So we need to repeat the conditions tensor accordingly
conditions = self.conditions.flatten(0, 1)[is_valid]

if self.log_rewards is None:
log_rewards = None
else:
log_rewards = torch.full(
(len(states),),
fill_value=-float("inf"),
device=states.device,
self.states.batch_shape, fill_value=-float("inf"), device=states.device
)
# Get the original indices (before flattening and filtering).
orig_batch_indices = torch.arange(
self.states.batch_shape[0], device=states.device
).repeat_interleave(self.states.batch_shape[1])
orig_traj_indices = torch.arange(
self.states.batch_shape[1], device=states.device
).repeat(self.states.batch_shape[0])

# Retain only the valid indices.
valid_batch_indices = orig_batch_indices[is_valid]
valid_traj_indices = orig_traj_indices[is_valid]

# Assign rewards to valid terminating states.
terminating_mask = is_terminating & (
valid_batch_indices == (self.terminating_idx[valid_traj_indices] - 1)
log_rewards[self.terminating_idx - 1, torch.arange(len(self))] = (
self.log_rewards
)
log_rewards[terminating_mask] = self.log_rewards[
valid_traj_indices[terminating_mask]
]
log_rewards = log_rewards[is_valid]

return StatesContainer[DiscreteStates](
env=self.env,
states=states,
conditions=conditions,
is_terminating=is_terminating,
log_rewards=log_rewards,
# FIXME: Add log_probs and estimator_outputs.
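The rewritten `to_states_container` above fills a `(max_length + 1, n_trajectories)` grid with `-inf`, writes each trajectory's log reward at `(terminating_idx - 1, trajectory_index)`, and then flattens the grid with the same boolean `is_valid` mask used for the states. Below is a standalone numeric sketch of that indexing pattern; the shapes and values are made up, and this `is_valid` keeps only the terminating cells, whereas the library version also keeps intermediate states (whose entries stay at `-inf`).

```python
import torch

T, n = 4, 3  # (max_length + 1) time steps and 3 trajectories, made up
terminating_idx = torch.tensor([2, 3, 1])
traj_log_rewards = torch.tensor([-0.5, -1.0, -2.0])

# Grid of -inf, one entry per (time step, trajectory) pair.
log_rewards = torch.full((T, n), fill_value=-float("inf"))
# Write each trajectory's log reward at its last state (terminating_idx - 1).
log_rewards[terminating_idx - 1, torch.arange(n)] = traj_log_rewards

# Validity mask over the same grid; boolean indexing flattens in row-major order.
is_valid = torch.zeros(T, n, dtype=torch.bool)
is_valid[terminating_idx - 1, torch.arange(n)] = True

print(log_rewards[is_valid])  # tensor([-2.0000, -0.5000, -1.0000])
```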
@@ -585,10 +528,15 @@ def reverse_backward_trajectories(self) -> Trajectories:
# new_states: (max_len + 2, n_trajectories, *state_dim)
# ---------------------------------------------------------------------

# Add conditions to the new states
if self.states.conditions is not None:
new_states.conditions = torch.cat(
(self.states.conditions[[0]], self.states.conditions), dim=0
)

reversed_trajectories = Trajectories(
env=self.env,
states=new_states,
conditions=self.conditions,
actions=new_actions,
terminating_idx=self.terminating_idx + 1,
is_backward=False,
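The final hunk keeps state-level conditions aligned with the reversed states: `new_states` gains one extra leading time step (`max_len + 2` rows), so the step-0 conditions row is duplicated in front of the original tensor. A standalone sketch of that alignment with made-up shapes:

```python
import torch

T, n, d = 3, 2, 4  # time steps, trajectories, condition dim, all made up
conditions = torch.randn(T, n, d)  # stand-in for self.states.conditions

# conditions[[0]] keeps the leading dimension, so the concatenation yields
# T + 1 rows: the step-0 conditions are reused for the extra initial state.
aligned = torch.cat((conditions[[0]], conditions), dim=0)
assert aligned.shape == (T + 1, n, d)
```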