Skip to content

Conversation

@hyeok9855
Copy link
Collaborator

@hyeok9855 hyeok9855 commented Nov 4, 2025

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

Major refactorings for conditional GFlowNets.

TODO (maybe in another PR?)

  • Let ConditionalEnv support conditional transitions

@hyeok9855 hyeok9855 marked this pull request as draft November 4, 2025 01:15
@hyeok9855 hyeok9855 self-assigned this Nov 4, 2025
@hyeok9855 hyeok9855 changed the title Refactor Conditional GFlowNets [WIP] Refactor Conditional GFlowNets Nov 4, 2025
@hyeok9855 hyeok9855 mentioned this pull request Nov 19, 2025
4 tasks
@hyeok9855 hyeok9855 marked this pull request as ready for review November 20, 2025 18:44
@hyeok9855 hyeok9855 changed the title [WIP] Refactor Conditional GFlowNets Refactor Conditional GFlowNets Nov 20, 2025
Copy link
Collaborator

@younik younik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few comments; good to go for me, but I would wait for @josephdviviano as he understands this code better

Comment on lines 349 to 353
# Concatenate conditions of the trajectories.
if self.conditions is not None and other.conditions is not None:
self.conditions = torch.cat((self.conditions, other.conditions), dim=0)
else:
self.conditions = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we maybe add a test for extending with conditions, and then try common ops like get_item to check the output is as expected?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we maybe add a test for extending with conditions

I will add one.

and then try common ops like get_item to check the output is as expected?

I have no idea what this means. Could you elaborate more?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean in the test, after calling extend, check if the extend operation gave the expected result.
Like here:

pre_extend_shape = state2.batch_shape
state1.extend(state2)
assert state2.batch_shape == pre_extend_shape
# Check final shape should be (max_len=3, B=4)
assert state1.batch_shape == (3, 4)
# The actual count might be higher due to padding with sink states
assert state1.tensor.x.size(0) == expected_nodes
assert state1.tensor.num_edges == expected_edges
# Check if states are extended as expected
assert (state1[0, 0].tensor.x == datas[0].x).all()
assert (state1[0, 1].tensor.x == datas[1].x).all()
assert (state1[0, 2].tensor.x == datas[4].x).all()
assert (state1[0, 3].tensor.x == datas[5].x).all()
assert (state1[1, 0].tensor.x == datas[2].x).all()
assert (state1[1, 1].tensor.x == datas[3].x).all()
assert (state1[1, 2].tensor.x == datas[6].x).all()
assert (state1[1, 3].tensor.x == datas[7].x).all()
assert (state1[2, 0].tensor.x == MyGraphStates.sf.x).all()
assert (state1[2, 1].tensor.x == MyGraphStates.sf.x).all()
assert (state1[2, 2].tensor.x == datas[8].x).all()
assert (state1[2, 3].tensor.x == datas[9].x).all()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I will add a test soon!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests are added, please check!

src/gfn/env.py Outdated
Comment on lines 452 to 464
def reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor:
"""Compute rewards for the conditional environment.
Args:
states: The states to compute rewards for.
states.tensor.shape should be (batch_size, *state_shape)
conditions: The conditions to compute rewards for.
conditions.shape should be (batch_size, condition_vector_dim)
Returns:
A tensor of shape (batch_size,) containing the rewards.
"""
raise NotImplementedError
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aha, this is not a real subclass of Env, as conditions are mandatory (i.e. if you can't call this function pretending it is an env obj while it is ConditionEnv).

Would it make sense to have a default condition?
If not, this shouldn't inehrit from Env probably.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to have a default condition?

How could having a default condition solve the problem?

If not, this shouldn't inherit from Env probably.

Maybe, but still we need a parent class that defines the default methods for Envs, like reward, step, etc...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How could having a default condition solve the problem?

If we have a function like this:

def get_reward(env: Env, states: States) -> torch.Tensor:
   return  env.reward(states)

This should work with any Env object, given the interface of Env.

However, currently, if I pass a ConditionEnv (which is an Env), this will fail as you need to specify the conditioning. If you have a default value for conditioning, now the get_reward function will work properly (indeed, with default, the reward function interface of ConditionEnv becomes a subtype of the one of Env)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative approach would be to have the conditions live inside the states themselves (states could have a conditioning field that is None unless conditioning is required, and then anything that accepts States follows a different path when conditioning is present).

The env itself would only be conditional or not depending on the logic the user defines in the reward and step functions. No actual ConditionalEnv class would be required.

The estimators would also optionally use the conditioning information, if it's present, just like how it's done currently.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I'm seeing something I never noticed before - this class makes the hot path for calculating from conditions tensor-based, which may or may not be more torch.compile friendly than using conditions in the states class.

The use of a ConditionalEnv is growing on me. I don't mind the changing API, but I would prefer if this logic was somehow all in the Env directly somehow. I keep changing my mind on the best design. I suppose it depends on whether we think putting the conditions in States is ultimately a good design.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative approach would be to have the conditions live inside the states themselves (states could have a conditioning field that is None unless conditioning is required, and then anything that accepts States follows a different path when conditioning is present).

Everything has been updated following this suggestion. Please review! Thanks.

Copy link
Collaborator

@josephdviviano josephdviviano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall a really nice PR, but I have a few questions about changes that seem unrelated to the goal (in particular I think we remove a few checks that might have side effects not captured in our test suites) and I wonder if it would be cleaner for the conditioning to live directly within the States class which would help avoid a lot of added complexity. We can discuss in the standup. Great work!

src/gfn/env.py Outdated
Comment on lines 452 to 464
def reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor:
"""Compute rewards for the conditional environment.
Args:
states: The states to compute rewards for.
states.tensor.shape should be (batch_size, *state_shape)
conditions: The conditions to compute rewards for.
conditions.shape should be (batch_size, condition_vector_dim)
Returns:
A tensor of shape (batch_size,) containing the rewards.
"""
raise NotImplementedError
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An alternative approach would be to have the conditions live inside the states themselves (states could have a conditioning field that is None unless conditioning is required, and then anything that accepts States follows a different path when conditioning is present).

The env itself would only be conditional or not depending on the logic the user defines in the reward and step functions. No actual ConditionalEnv class would be required.

The estimators would also optionally use the conditioning information, if it's present, just like how it's done currently.

Copy link
Collaborator

@josephdviviano josephdviviano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I'll leave comments - we can decide what to do with the other PR before deciding what to do for this one.

But I must say there's a lot of good work here. Thank you, I'm sure much of this will be a good improvement to the library!

# We need to index the conditions tensor to match the actions
# The actions exclude the last step, so we need to exclude the last step from conditions
conditions = self.conditions[:-1][~self.actions.is_dummy]
# The conditions tensor has shape (n_trajectories, condition_vector_dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is n_trajectories, batch_dim? That naming is a bit confusing because there's also trajectory_length.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, n_trajectories is actually batch_size. I agree that this is confusing. But i think it's better to fix this in a separate PR.

src/gfn/env.py Outdated
Comment on lines 452 to 464
def reward(self, states: States, conditions: torch.Tensor) -> torch.Tensor:
"""Compute rewards for the conditional environment.
Args:
states: The states to compute rewards for.
states.tensor.shape should be (batch_size, *state_shape)
conditions: The conditions to compute rewards for.
conditions.shape should be (batch_size, condition_vector_dim)
Returns:
A tensor of shape (batch_size,) containing the rewards.
"""
raise NotImplementedError
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I'm seeing something I never noticed before - this class makes the hot path for calculating from conditions tensor-based, which may or may not be more torch.compile friendly than using conditions in the states class.

The use of a ConditionalEnv is growing on me. I don't mind the changing API, but I would prefer if this logic was somehow all in the Env directly somehow. I keep changing my mind on the best design. I suppose it depends on whether we think putting the conditions in States is ultimately a good design.

if not env.is_discrete:
raise NotImplementedError(
"Flow Matching GFlowNet only supports discrete environments for now."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, it's handled here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, can't we use FM for GraphEnv?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason is we don't restrict graphs to being only Discrete, they can have Continuous attributes.

@codecov
Copy link

codecov bot commented Dec 14, 2025

Codecov Report

❌ Patch coverage is 76.54723% with 72 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.38%. Comparing base (91aae26) to head (e7a360b).
⚠️ Report is 25 commits behind head on master.

Files with missing lines Patch % Lines
src/gfn/states.py 72.41% 21 Missing and 11 partials ⚠️
src/gfn/gym/hypergrid.py 83.58% 3 Missing and 8 partials ⚠️
src/gfn/env.py 71.42% 9 Missing and 1 partial ⚠️
src/gfn/utils/prob_calculations.py 76.00% 5 Missing and 1 partial ⚠️
src/gfn/gym/bitSequence.py 25.00% 3 Missing ⚠️
src/gfn/gflownet/detailed_balance.py 80.00% 1 Missing and 1 partial ⚠️
src/gfn/gflownet/flow_matching.py 75.00% 1 Missing and 1 partial ⚠️
src/gfn/utils/training.py 33.33% 1 Missing and 1 partial ⚠️
src/gfn/containers/replay_buffer.py 0.00% 0 Missing and 1 partial ⚠️
src/gfn/gflownet/sub_trajectory_balance.py 85.71% 1 Missing ⚠️
... and 2 more
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #431      +/-   ##
==========================================
+ Coverage   74.23%   74.38%   +0.14%     
==========================================
  Files          47       47              
  Lines        6805     6891      +86     
  Branches      800      825      +25     
==========================================
+ Hits         5052     5126      +74     
- Misses       1449     1454       +5     
- Partials      304      311       +7     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Collaborator

@josephdviviano josephdviviano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks really great @hyeok9855 thanks for your hard work on this

@josephdviviano josephdviviano merged commit 24dd86d into master Dec 18, 2025
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants