Add warm-up functionality with tensor to trajectory helper functions #224
Conversation
Add a function to generate trajectories from states and actions tensors. Add a function to crudely warm up a GFN (early stopping or other tricks not included).
…ging since every other GFN loss method returns tensor
Thank you for the PR. Could you please elaborate a little bit more on it? What use-case are you targeting? Where do you use the new functions? Is there a way to test them and see their effects in the repo? Thanks
Hi Salem! Yes, sorry, I contacted Joseph via Slack prior to the PR, but I should've given more detail on here. These functions are provided as a means to generate warmup trajectories from external state-action tensors (e.g. expert knowledge, or another algorithm's output). My rationale for PR'ing these simple functions is that I found the whole process to be non-trivial when looking at the sources/docs (namely, watch for the …).

```python
def states_actions_tns_to_traj(
    states_tns: torch.Tensor,
    actions_tns: torch.Tensor,
    env: DiscreteEnv,
) -> Trajectories:
```

is a utility function that maps state tensors and action tensors to a `Trajectories` container.

```python
def warm_up(
    replay_buf: ReplayBuffer,
    optimizer: torch.optim.Optimizer,
    gfn: GFlowNet,
    env: Env,
    n_steps: int,
    batch_size: int,
    recalculate_all_logprobs=True,
):
```

is a training loop over a fixed replay buffer, but does not assume that any log-probs were computed in the `Trajectories`. I can write some unit-tests for these functions. If you have any other feedback, send it my way so that we can implement it and follow your philosophy more closely.

Edit: I clarified why the warm-up function was important to this PR.
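To make the intended workflow concrete, here is a minimal usage sketch (mine, not code from this PR), assuming the helpers live in `gfn.utils.training` (as in the review below) and reusing the `HyperGrid` setup from the MRE further down; `my_gfn` and `my_optimizer` are hypothetical placeholders for an already-built GFlowNet and its optimizer:

```python
import torch

from gfn.containers.replay_buffer import ReplayBuffer
from gfn.gym.hypergrid import HyperGrid
from gfn.utils.training import states_actions_tns_to_traj, warm_up

env = HyperGrid(2, 4)

# One expert trajectory on the 2D grid: (0, 0) -> (0, 1) -> (0, 2) -> exit,
# with the sink state (-1, -1) appended; actions 0/1 increment a coordinate, 2 is exit.
expert_states = torch.tensor([[0, 0], [0, 1], [0, 2], [-1, -1]])
expert_actions = torch.tensor([1, 1, 2])

# Convert the raw tensors into a Trajectories container and store it in a replay buffer.
trajs = states_actions_tns_to_traj(expert_states, expert_actions, env)
replay_buffer = ReplayBuffer(env, "trajectories")
replay_buffer.add(trajs)

# Warm up an existing GFlowNet on the buffered expert data before regular training.
# warm_up(replay_buffer, my_optimizer, my_gfn, env, n_steps=100, batch_size=1)
```

(Note that the replay-buffer step is exactly what the MRE further down exercises, so this only works once that fix is in.)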
Thank you for the PR. Possible docstring to add: … For the `warm_up` function, a docstring would be appreciated. I am not sure why …
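For what it's worth, a possible shape for such a docstring (my wording, not from the PR) could be:

```python
def warm_up(
    replay_buf: ReplayBuffer,
    optimizer: torch.optim.Optimizer,
    gfn: GFlowNet,
    env: Env,
    n_steps: int,
    batch_size: int,
    recalculate_all_logprobs=True,
):
    """Trains a GFlowNet for n_steps on trajectories sampled from a fixed replay buffer.

    Intended for warm-starting a GFlowNet from externally provided trajectories
    (e.g. expert demonstrations) before regular training. No early stopping or
    other tricks are applied.

    Args:
        replay_buf: replay buffer holding the warm-up trajectories.
        optimizer: optimizer over the GFlowNet's parameters.
        gfn: the GFlowNet to train.
        env: the environment the trajectories belong to.
        n_steps: number of gradient steps over the buffer.
        batch_size: number of trajectories sampled from the buffer per step.
        recalculate_all_logprobs: whether to recompute log-probs for the sampled
            trajectories rather than relying on stored ones (the buffered
            trajectories may not carry any log-probs).
    """
    ...  # training loop as in the PR (imports and body elided here)
```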
src/gfn/utils/training.py (outdated)
```python
for epoch in t:
    training_trajs = replay_buf.sample(batch_size)
    optimizer.zero_grad()
    if isinstance(gfn, TBGFlowNet):
```
With #231, this could be changed to a cleaner test (checking whether it's a PFBasedGFlowNet).
Yes! Seeing your commit, I think this would be cleaner.
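Concretely, the cleaner check could look roughly like this (a sketch of mine, assuming `PFBasedGFlowNet` is importable from `gfn.gflownet` once #231 is merged, and the `gfn.loss(env, trajectories, ...)` call pattern from the hunk above):

```python
import torch

from gfn.containers.replay_buffer import ReplayBuffer
from gfn.env import Env
from gfn.gflownet import GFlowNet, PFBasedGFlowNet


def warm_up_sketch(
    replay_buf: ReplayBuffer,
    optimizer: torch.optim.Optimizer,
    gfn: GFlowNet,
    env: Env,
    n_steps: int,
    batch_size: int,
    recalculate_all_logprobs: bool = True,
):
    """Warm-up loop dispatching on PFBasedGFlowNet rather than TBGFlowNet only."""
    for _ in range(n_steps):
        training_trajs = replay_buf.sample(batch_size)
        optimizer.zero_grad()
        if isinstance(gfn, PFBasedGFlowNet):
            # PF-based losses can recompute log-probs for externally provided trajectories.
            loss = gfn.loss(
                env, training_trajs, recalculate_all_logprobs=recalculate_all_logprobs
            )
        else:
            loss = gfn.loss(env, training_trajs)
        loss.backward()
        optimizer.step()
```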
- Add docstrings
- Add input validation (as proposed by saleml)
- Add PFBasedGFlowNet verification instead of only TBGFNs (needs merge of GFNOrg#231)
josephdviviano
left a comment
Hey, first I want to apologize for taking so long to review this. I hit a bit of a lull over Dec / early Jan and have been playing catchup.
This is a really nice PR, and a feature I'd be excited to use myself in some of the applications I've been looking at. My only request revolves around the use of the dummy log_probs - if our library is working properly, it should function as intended using log_probs=None, and if not, we should fix the downstream elements if they're misbehaving, because this is the intended use of the Trajectories container.
Awesome contribution, thank you very much!
@josephdviviano After testing by sending `log_probs = None`, it seems like it is working. However, on any `Trajectories.extend()` call, I get the following error:

```
Traceback (most recent call last):
  File "/Users/quoding/Documents/PhD/gfn-explain/scripts/experiments/knapsack/basic_gfn.py", line 225, in main
    replay_buf.add(trajectories)
  File "/Users/quoding/.pyenv/versions/gfn_tf/lib/python3.10/site-packages/gfn/containers/replay_buffer.py", line 177, in add
    self._add_objs(training_objects)
  File "/Users/quoding/.pyenv/versions/gfn_tf/lib/python3.10/site-packages/gfn/containers/replay_buffer.py", line 149, in _add_objs
    self.training_objects.extend(training_objects)
  File "/Users/quoding/.pyenv/versions/gfn_tf/lib/python3.10/site-packages/gfn/containers/trajectories.py", line 260, in extend
    assert self.log_probs.shape == self.actions.batch_shape
AssertionError
```

Therefore, it seems that doing things this way is problematic with replay buffers, among other things.

I noticed that the assert statement has a TODO comment over it saying it could be removed? Was that one of the reasons why?

Edit: …
Ok great, thanks for testing. We can patch up the downstream elements that are complaining; this isn't expected behaviour.
Thank you @alexandrelarouche for the PR. I vote to merge it (once isort is fixed). Could you please run …
@alexandrelarouche can you post a minimal example I can use to replicate your error?
Yes. Here is an MRE which fails on the log_rewards shape (instead of the log_probs, previously). It seems like the same chunk of code is responsible.

MRE:

```python
from gfn.containers.replay_buffer import ReplayBuffer
from gfn.gym.hypergrid import HyperGrid
from gfn.utils.training import states_actions_tns_to_traj
import torch

if __name__ == "__main__":
    env = HyperGrid(2, 4)
    states = torch.tensor([[0, 0], [0, 1], [0, 2], [-1, -1]])
    actions = torch.tensor([1, 1, 2])
    replay_buffer = ReplayBuffer(env, "trajectories")
    trajs = states_actions_tns_to_traj(states, actions, env)
    replay_buffer.add(trajs)  # Errors happen here
```

Error:

```
Traceback (most recent call last):
  File "/Users/quoding/Documents/Code/torchgfn/scripts.py", line 16, in <module>
    replay_buffer.add(trajs)  # Errors happen here
    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/quoding/Documents/Code/torchgfn/src/gfn/containers/replay_buffer.py", line 74, in add
    self.training_objects.extend(training_objects)
  File "/Users/quoding/Documents/Code/torchgfn/src/gfn/containers/trajectories.py", line 286, in extend
    assert len(self.log_rewards) == self.actions.batch_shape[-1]
AssertionError
```

Responsible code: trajectories.py, L281:

```python
# Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
if self.log_probs.numel() > 0:
    assert self.log_probs.shape == self.actions.batch_shape
if self.log_rewards is not None:
    assert len(self.log_rewards) == self.actions.batch_shape[-1]
```
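As a side note, the shapes that the failing assert compares can be inspected directly in the MRE (a debugging addition of mine, not part of the PR):

```python
# Hypothetical debugging lines, to be placed just before the failing replay_buffer.add(trajs)
# call in the MRE above. If I read the Trajectories container right, actions are batched as
# (max_length, n_trajectories), so the assert wants len(log_rewards) == n_trajectories (= 1 here).
print(trajs.actions.batch_shape)  # expected (3, 1) for this single 3-action trajectory
print(trajs.log_rewards)          # whatever states_actions_tns_to_traj stored as log-rewards
```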
@alexandrelarouche I took the liberty to commit (what I hope is) the fix to your branch. Please feel free to revert the changes if you're not satisfied with my fix.
alexandrelarouche
left a comment
This seems good overall, but I am confused as to why we would want to obscure the code further with a class-method `stack` for states while there is a function for this. The only reason I used the class method for actions was the absence of a `stack_actions` equivalent to `stack_states`.
```diff
  actions = actions[0].stack(actions)
  log_rewards = env.log_reward(states[-2])
- states = stack_states(states)
+ states = states[0].stack_states(states)
```
The only reason I used the class method for actions was the absence of a stack_actions equivalent to stack_states. I think I would keep stack_states here, in other words.
We moved the function stack_states into the class, which is why Salem made this change I think. The import no longer exists after updating (I believe he pulled from main).
josephdviviano
left a comment
Everything seems to be working on my side. Though I don't truly understand what the fix was, in the interest of time I'll approve and we can move on.
```python
    states_tns: torch.Tensor,
    actions_tns: torch.Tensor,
    env: DiscreteEnv,
    conditioning: torch.Tensor | None = None,
```
@saleml was the fix simply to account for conditioning? Otherwise I don't see what you did, other than changing the function call to a method call.
Apologies, as I should have explained earlier. Basically, the codebase changed between the moment you started the PR and the moment we merged it, so I needed to merge master into your branch. The remaining problem was the need to check that …
I'm hoping to get your insight on how to make this better. Some parts of the code are sketchy and have been highlighted with a WARNING tag in the comments.