
Conversation

@alexandrelarouche (Contributor) commented Dec 11, 2024

I'm hoping to get your insight on how to make this better. Some parts of the code are sketchy and have been highlighted with a WARNING tag in the comments.

Add function to generate trajectories from states and actions tensors
Add function to crudely warmup a GFN (early stopping or other tricks not included)
…ging since every other GFN loss method returns tensor
@saleml (Collaborator) commented Jan 11, 2025

Thank you for the PR. Could you please elaborate a little bit more on it? What use-case are you targeting? Where do you use the new functions? Is there a way to test them and see their effects in the repo?

Thanks

@alexandrelarouche (Contributor, Author) commented Jan 12, 2025

Hi Salem!

Yes, sorry, I contacted Joseph via Slack prior to the PR, but I should've given more detail here.

These functions are provided as a means to generate warmup trajectories from external state-action tensors (e.g., expert knowledge, or another algorithm's output). My rationale for PR'ing these simple functions is that I found the whole process to be non-trivial when working from the sources/docs (hence the WARNING tags), and I thought other users could benefit from having either a full implementation or an example.

def states_actions_tns_to_traj(
    states_tns: torch.Tensor,
    actions_tns: torch.Tensor,
    env: DiscreteEnv,
) -> Trajectories:

is a utility function that maps state tensors and actions to a Trajectories object. Effectively, it is a translation function for a prior that comes from outside the torchgfn ecosystem and would not already be wrapped in a Trajectories.
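For context, here is a rough sketch of what the conversion involves, pieced together from code fragments quoted later in this thread. The states_from_tensor/actions_from_tensor calls, the reshape details, and the Trajectories keyword arguments are assumptions for illustration, not necessarily the merged implementation:

    # Hedged sketch: wrap each step's tensors in States/Actions objects,
    # compute the log-reward from the last non-sink state, then stack
    # everything into a batch containing this single trajectory.
    states = [env.states_from_tensor(s.unsqueeze(0)) for s in states_tns]
    actions = [env.actions_from_tensor(a.reshape(1, 1)) for a in actions_tns]
    log_rewards = env.log_reward(states[-2])  # states[-1] is the sink state
    states = states[0].stack_states(states)
    actions = actions[0].stack(actions)
    return Trajectories(env=env, states=states, actions=actions, log_rewards=log_rewards)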

def warm_up(
    replay_buf: ReplayBuffer,
    optimizer: torch.optim.Optimizer,
    gfn: GFlowNet,
    env: Env,
    n_steps: int,
    batch_size: int,
    recalculate_all_logprobs: bool = True,
):

is a training loop over a fixed replay buffer that does not assume any log-probs were computed in the Trajectories generated by the prior. Anyone could implement their own version; I simply provided mine as a crude example (there is no early stopping here, or any other training trick). The important/tricky bit, if I remember correctly, lies in setting recalculate_all_logprobs=True for TB-GFNs, because the states_actions_tns_to_traj function creates dummy log-prob tensors (since the prior is not expected to provide those).
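For reference, here is a minimal sketch of the loop body, reconstructed from the fragment quoted in the review below; the iterator (a plain range over n_steps here, in place of the original tqdm-style t) and the exact gfn.loss signature are assumptions:

    for epoch in range(n_steps):
        training_trajs = replay_buf.sample(batch_size)
        optimizer.zero_grad()
        if isinstance(gfn, TBGFlowNet):
            # TB losses must recompute log-probs here, because the buffered
            # trajectories only carry the dummy values created by
            # states_actions_tns_to_traj.
            loss = gfn.loss(env, training_trajs, recalculate_all_logprobs=recalculate_all_logprobs)
        else:
            loss = gfn.loss(env, training_trajs)
        loss.backward()
        optimizer.step()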

I can write some unit tests for states_actions_tns_to_traj, as I think it is the trickier function of the two. I can also add the docstrings (which I thought I had provided, my bad).

If you have any other feedback, send it my way so that we can implement it and follow your philosophy more closely.

Edit: I clarified why the warm-up function was important to this PR

@saleml (Collaborator) commented Jan 22, 2025

Thank you for the PR.
The states_actions_tns_to_traj function needs better input validation and documentation. Here's how I would modify it:

    if states_tns.shape[1:] != env.state_shape:
        raise ValueError(
            f"states_tns state dimensions must match env.state_shape {env.state_shape}, "
            f"got shape {states_tns.shape[1:]}"
        )
    if len(actions_tns.shape) != 1:
        raise ValueError(f"actions_tns must be 1D, got shape {actions_tns.shape}")
    if states_tns.shape[0] != actions_tns.shape[0] + 1:
        raise ValueError(
            f"states_tns must contain one more entry than actions_tns (the final sink state), "
            f"got states: {states_tns.shape[0]}, actions: {actions_tns.shape[0]}"
        )

    # ... rest of the code ...

Possible docstring to add:

   
    """Convert state and action tensors for a single trajectory into a Trajectories object.

    This utility function helps integrate external data (e.g. expert demonstrations)
    into the GFlowNet framework by converting raw tensors into proper Trajectories objects.
   
   Args:
        states_tns: Tensor of shape [traj_len + 1, *state_shape] containing the states of a single trajectory, including the final sink state
        actions_tns: Tensor of shape [traj_len] containing discrete action indices, including the exit action
       env: The discrete environment that defines the state/action spaces
       
   Returns:
       Trajectories: A Trajectories object containing the converted states and actions
       
   Raises:
       ValueError: If tensor shapes are invalid or inconsistent
   """

For warm_up, a docstring would be appreciated. I am not sure why gfn.loss admits an extra argument for the TB loss. I will investigate it.

for epoch in t:
    training_trajs = replay_buf.sample(batch_size)
    optimizer.zero_grad()
    if isinstance(gfn, TBGFlowNet):
Collaborator:

With #231, this could be changed to a cleaner test (checking whether it is a PFBasedGFlowNet).
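For illustration, the cleaner check would presumably read (class name taken from #231):

    if isinstance(gfn, PFBasedGFlowNet):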

Contributor (Author):

Yes! Seeing your commit, I think this would be cleaner.

Add docstrings
Add input validation (as proposed by saleml)
Add PFBasedGFlowNet verification instead of only TBGFNs (needs merge of GFNOrg#231)
@josephdviviano self-assigned this Jan 24, 2025
@josephdviviano (Collaborator) left a comment:

Hey, first I want to apologize for taking so long to review this. I hit a bit of a lull over Dec / early Jan and have been playing catchup.

This is a really nice PR, and a feature I'd be excited to use myself in some of the applications I've been looking at. My only request concerns the use of the dummy log_probs: if our library is working properly, it should function as intended with log_probs=None, and if not, we should fix whichever downstream elements are misbehaving, because this is the intended use of the Trajectories container.

Awesome contribution, thank you very much!

@alexandrelarouche (Contributor, Author) commented Jan 29, 2025

@josephdviviano After testing with log_probs = None, it seems to work. However, on any Trajectories.extend() call, I get the following error:

Traceback (most recent call last):
  File "/Users/quoding/Documents/PhD/gfn-explain/scripts/experiments/knapsack/basic_gfn.py", line 225, in main
    replay_buf.add(trajectories)
  File "/Users/quoding/.pyenv/versions/gfn_tf/lib/python3.10/site-packages/gfn/containers/replay_buffer.py", line 177, in add
    self._add_objs(training_objects)
  File "/Users/quoding/.pyenv/versions/gfn_tf/lib/python3.10/site-packages/gfn/containers/replay_buffer.py", line 149, in _add_objs
    self.training_objects.extend(training_objects)
  File "/Users/quoding/.pyenv/versions/gfn_tf/lib/python3.10/site-packages/gfn/containers/trajectories.py", line 260, in extend
    assert self.log_probs.shape == self.actions.batch_shape
AssertionError

Therefore, it seems that doing things this way is problematic with replay buffers, among other things.
I noticed that the assert statement has a TODO comment above it suggesting it could be removed. Was that one of the reasons why?

Edit:
It seems there is a mismatch between how log_probs is instantiated when log_probs = None and the shape expected when log_probs are given to the trajectory.

@josephdviviano (Collaborator) commented Jan 29, 2025 via email

@saleml (Collaborator) commented Feb 8, 2025

Thank you @alexandrelarouche for the PR. I vote to merge it (once isort is fixed). Could you please run pre-commit run --all-files before committing? (You can skip the pytest checks locally, as they are long.)

@josephdviviano (Collaborator) commented:

@alexandrelarouche can you post a minimal example I can use to replicate your error?

@alexandrelarouche (Contributor, Author) commented Feb 11, 2025

Yes. Here is a minimal reproducible example (MRE), which now fails on the log_rewards shape (instead of the log_probs shape, as previously). The same chunk of code seems to be responsible.

MRE:

from gfn.containers.replay_buffer import ReplayBuffer
from gfn.gym.hypergrid import HyperGrid
from gfn.utils.training import states_actions_tns_to_traj
import torch


if __name__ == "__main__":
    env = HyperGrid(2, 4)  # 2-dimensional grid of height 4
    states = torch.tensor([[0, 0], [0, 1], [0, 2], [-1, -1]])  # [-1, -1] is the sink state
    actions = torch.tensor([1, 1, 2])  # action 2 is the exit action for ndim=2
    replay_buffer = ReplayBuffer(env, "trajectories")
    trajs = states_actions_tns_to_traj(states, actions, env)

    replay_buffer.add(trajs)  # Errors happen here

Error:

Traceback (most recent call last):
  File "/Users/quoding/Documents/Code/torchgfn/scripts.py", line 16, in <module>
    replay_buffer.add(trajs)  # Errors happen here
    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/quoding/Documents/Code/torchgfn/src/gfn/containers/replay_buffer.py", line 74, in add
    self.training_objects.extend(training_objects)
  File "/Users/quoding/Documents/Code/torchgfn/src/gfn/containers/trajectories.py", line 286, in extend
    assert len(self.log_rewards) == self.actions.batch_shape[-1]
AssertionError

Responsible code (trajectories.py, around line 281):
        # Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
        if self.log_probs.numel() > 0:
            assert self.log_probs.shape == self.actions.batch_shape

        if self.log_rewards is not None:
            assert len(self.log_rewards) == self.actions.batch_shape[-1]

@saleml (Collaborator) commented Feb 12, 2025

@alexandrelarouche I took the liberty of committing (what I hope is) the fix to your branch. Please feel free to revert the changes if you're not satisfied with my fix.

@alexandrelarouche (Contributor, Author) left a comment:

This seems good overall, but I am confused as to why we would want to obscure the code further with the class-method stack for states when there is a standalone function for this. The only reason I used the class method for actions was the absence of a stack_actions equivalent to stack_states.

  actions = actions[0].stack(actions)
  log_rewards = env.log_reward(states[-2])
- states = stack_states(states)
+ states = states[0].stack_states(states)
Contributor (Author):

As noted above, the only reason I used the class method for actions was the absence of a stack_actions equivalent to stack_states. In other words, I would keep stack_states here.

Collaborator:

We moved the stack_states function into the class, which I think is why Salem made this change. The import no longer exists after updating (I believe he pulled from main).

@josephdviviano (Collaborator) left a comment:

Everything seems to be working on my side. Though I don't truly understand what the fix was, in the interest of time I'll approve and we can move on.


    states_tns: torch.Tensor,
    actions_tns: torch.Tensor,
    env: DiscreteEnv,
    conditioning: torch.Tensor | None = None,
Collaborator:

@saleml was the fix simply to account for conditioning? Otherwise I don't see what you did, other than changing the function call to a method call.

@josephdviviano merged commit fb51133 into GFNOrg:master on Feb 14, 2025
4 checks passed
@saleml (Collaborator) commented Feb 18, 2025

Hi @alexandrelarouche

Apologies as I should have explained earlier.

Basically, the codebase changed between the moment you started the PR and the moment we merged it, so I needed to merge master into your branch. The remaining problem was the need to check that log_probs is not only not None, but also not "empty".
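For illustration, the guard this describes would look roughly like the following sketch, consistent with the trajectories.py snippet quoted above but not necessarily the exact committed fix:

        # Only compare shapes when log_probs actually carries values; a None
        # or empty tensor (e.g. the dummy log-probs) is skipped.
        if self.log_probs is not None and self.log_probs.numel() > 0:
            assert self.log_probs.shape == self.actions.batch_shape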
