Draft

Changes from all commits (35 commits)
d3148f7
added policy adaptors, factorized samplers to allow for modular adapt…
josephdviviano Sep 29, 2025
93a654a
fixed _SeqStates
josephdviviano Sep 29, 2025
33f96b1
Merge branch 'identity_preprocessor_remove_shape_checking' of github.…
josephdviviano Sep 29, 2025
14b110c
Update input_dim to use preprocessor output_dim
josephdviviano Sep 29, 2025
c7a3d8c
Update input_dim to use preprocessor's output_dim
josephdviviano Sep 29, 2025
788ff96
Merge branch 'master' of github.com:GFNOrg/torchgfn into generalize_s…
josephdviviano Sep 29, 2025
8eb98a5
Merge branch 'generalize_samplers' of github.com:GFNOrg/torchgfn into…
josephdviviano Sep 29, 2025
b3ed2bb
Merge branch 'generalize_samplers' of github.com:GFNOrg/torchgfn into…
josephdviviano Oct 3, 2025
a4fc53a
added draft of chunking logic -- need to test on some discrete enviro…
josephdviviano Oct 8, 2025
02856ee
added dtype casting to preprocessors
josephdviviano Oct 9, 2025
1224bc0
added vectorized and non-vectorized adapter-based probability calcula…
josephdviviano Oct 9, 2025
34202a8
added documentation
josephdviviano Oct 9, 2025
d1db3bd
removed strange change to documentation
josephdviviano Oct 9, 2025
08bf6eb
removed strange change to documentation
josephdviviano Oct 9, 2025
baa50e4
added basic recurrent bitsequence algorithm
josephdviviano Oct 9, 2025
e8d3fc2
added working bitsequence example for recurrent estimators and their …
josephdviviano Oct 10, 2025
e0dd464
fixed test
josephdviviano Oct 10, 2025
9d857ca
Merge branch 'generalize_samplers' of github.com:GFNOrg/torchgfn into…
josephdviviano Oct 10, 2025
496df98
Merge branch 'master' into chunking
josephdviviano Oct 10, 2025
bca3df6
Update estimators.py
josephdviviano Oct 12, 2025
fc6cb7a
black / isort
josephdviviano Oct 13, 2025
e2dc289
simplification of the contex, adapter logic, compression of documenta…
josephdviviano Oct 13, 2025
3c2862f
streamlined adapters under their own module
josephdviviano Oct 13, 2025
4a23ea0
typing
josephdviviano Oct 13, 2025
e2755e6
removed strict type ceck
josephdviviano Oct 13, 2025
ba6f0bd
shrank docs
josephdviviano Oct 13, 2025
db36953
added notes
josephdviviano Oct 13, 2025
3226660
removed finalize
josephdviviano Oct 13, 2025
4c2c1df
removed check_cond_forward
josephdviviano Oct 13, 2025
d066c97
removed record step
josephdviviano Oct 13, 2025
e638be9
lint errors
josephdviviano Oct 14, 2025
1ee6a8f
autoflake
josephdviviano Oct 14, 2025
6026008
minor formatting
josephdviviano Oct 14, 2025
aeec438
Merge pull request #413 from GFNOrg/make_adapters_part_of_estimators
josephdviviano Oct 14, 2025
f1e51c2
Merge branch 'generalize_samplers' of github.com:GFNOrg/torchgfn into…
josephdviviano Oct 14, 2025
180 changes: 180 additions & 0 deletions docs/source/guides/estimator_adapters.md
@@ -0,0 +1,180 @@
# PolicyMixin: Policies and Rollouts

Estimators become policy-capable by mixing in a small, uniform rollout API. This lets the same `Sampler` and probability utilities drive different estimator families (discrete, graph, conditional, recurrent) without bespoke glue code.

This guide explains:
- The Policy rollout API and `RolloutContext`
- Vectorized vs non‑vectorized probability paths
- How policies integrate with the `Sampler` and probability calculators
- How to implement a new policy mixin or tailor the default behavior

## Concepts and Goals

A policy‑capable estimator exposes:
- `is_vectorized: bool` — whether the estimator can be evaluated in a single vectorized call (no per‑step carry).
- `init_context(batch_size, device, conditioning)` — allocate a per‑rollout context.
- `compute_dist(states_active, ctx, step_mask, ...) -> (Distribution, ctx)` — run the model, build a `torch.distributions.Distribution`.
- `log_probs(actions_active, dist, ctx, step_mask, vectorized, ...) -> (Tensor, ctx)` — evaluate log‑probs, optionally padded to batch.
- `get_current_estimator_output(ctx)` — access the last raw model output when requested.

All per‑step artifacts (e.g., log‑probs, raw outputs, recurrent state) are owned by the `RolloutContext` and recorded by the mixin.
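
A minimal sketch of this interface, with simplified signatures (names follow the list above; `src/gfn/estimators.py` holds the authoritative definitions):

```python
from typing import Any, Optional, Protocol

import torch
from torch.distributions import Distribution


class PolicyCapable(Protocol):
    """Illustrative protocol for the rollout API; real estimators get it from PolicyMixin."""

    is_vectorized: bool

    def init_context(
        self,
        batch_size: int,
        device: torch.device,
        conditioning: Optional[torch.Tensor] = None,
    ) -> Any: ...

    def compute_dist(
        self, states_active: Any, ctx: Any, step_mask: Optional[torch.Tensor] = None, **kwargs: Any
    ) -> tuple[Distribution, Any]: ...

    def log_probs(
        self,
        actions_active: Any,
        dist: Distribution,
        ctx: Any,
        step_mask: Optional[torch.Tensor] = None,
        vectorized: bool = False,
        save_logprobs: bool = False,
    ) -> tuple[torch.Tensor, Any]: ...

    def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: ...
```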

## RolloutContext

The `RolloutContext` is a lightweight container created once per rollout:
- `batch_size`, `device`, optional `conditioning`
- Optional `carry` (for recurrent policies)
- Per‑step buffers: `trajectory_log_probs`, `trajectory_estimator_outputs`
- `current_estimator_output` for cached reuse or immediate retrieval
- `extras: dict` for arbitrary policy‑specific data

See `src/gfn/estimators.py` for the full definition.
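
A rough sketch of its shape (field names follow the list above; the actual class may differ in details):

```python
from dataclasses import dataclass, field
from typing import Any, Optional

import torch


@dataclass
class RolloutContextSketch:
    """Illustrative stand-in for the real RolloutContext in src/gfn/estimators.py."""

    batch_size: int
    device: torch.device
    conditioning: Optional[torch.Tensor] = None
    carry: Optional[Any] = None  # only used by recurrent policies
    trajectory_log_probs: list[torch.Tensor] = field(default_factory=list)
    trajectory_estimator_outputs: list[torch.Tensor] = field(default_factory=list)
    current_estimator_output: Optional[torch.Tensor] = None
    extras: dict[str, Any] = field(default_factory=dict)
```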

## PolicyMixin (vectorized, default)

`PolicyMixin` enables vectorized evaluation by default (`is_vectorized=True`).

- `init_context(batch_size, device, conditioning)` returns a fresh `RolloutContext` with empty buffers.
- `compute_dist(...)`:
  - Slices `conditioning` by `step_mask` when provided; uses full `conditioning` when `step_mask=None` (vectorized).
  - Optionally reuses `ctx.current_estimator_output` (e.g., PF with cached `trajectories.estimator_outputs`).
  - Calls the estimator module and builds a `Distribution` via `to_probability_distribution`.
  - When `save_estimator_outputs=True`, sets `ctx.current_estimator_output` and records a padded copy to `ctx.trajectory_estimator_outputs` for non‑vectorized calls.
- `log_probs(...)`:
  - `vectorized=True`: returns raw `dist.log_prob(...)` (may include `-inf` for illegal actions) and optionally records to `trajectory_log_probs`.
  - `vectorized=False`: strict inf‑check, pads to shape `(N,)` using `step_mask`, records when requested.

Code reference (log‑probs behavior): `src/gfn/estimators.py`.
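
In the vectorized case the whole batch is evaluated in one shot, roughly as follows (a sketch; `policy` is a `PolicyMixin` estimator and `states`/`actions` are a compatible batch):

```python
ctx = policy.init_context(batch_size=states.batch_shape[0], device=states.device, conditioning=None)
dist, ctx = policy.compute_dist(states, ctx, step_mask=None)
log_probs, ctx = policy.log_probs(actions, dist, ctx, step_mask=None, vectorized=True)
```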

## RecurrentPolicyMixin (per‑step)

`RecurrentPolicyMixin` sets `is_vectorized=False` and threads a carry through steps:

- `init_context(...)` requires the estimator to implement `init_carry(batch_size, device)`; stores the result in `ctx.carry`.
- `compute_dist(...)` must call the estimator as `(states_active, ctx.carry) -> (est_out, new_carry)`, update `ctx.carry`, build the `Distribution`, and record outputs when requested (with padding when masked).
- `log_probs(...)` follows the non‑vectorized path (pad and strict checks) and can reuse the same recording semantics as `PolicyMixin`.

Code reference (carry update and padded recording): `src/gfn/estimators.py`.
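
Schematically, one step of a recurrent rollout looks like this (a sketch, not the library's exact code; `policy` follows the `RecurrentPolicyMixin` API):

```python
def recurrent_step(policy, states_active, ctx, step_mask):
    # `states_active` are the still-active states; `step_mask` is a (batch_size,)
    # boolean tensor marking their positions in the full batch.
    dist, ctx = policy.compute_dist(states_active, ctx, step_mask)  # also updates ctx.carry
    actions = dist.sample()
    log_probs, ctx = policy.log_probs(
        actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True
    )  # padded back to (batch_size,) and recorded in ctx.trajectory_log_probs
    return actions, log_probs, ctx
```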

## Integration with the Sampler

The `Sampler` uses the policy API directly. It creates a single `ctx` per rollout, then repeats `compute_dist` → sample → optional `log_probs` while some trajectories are active. Per‑step artifacts are recorded into `ctx` by the mixin when flags are enabled.

Excerpt (per‑step call pattern): `src/gfn/samplers.py`.
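
A simplified sketch of that loop (names like `done` and `env_step` are illustrative, not the actual `Sampler` internals):

```python
ctx = estimator.init_context(batch_size=n, device=device, conditioning=conditioning)
done = torch.zeros(n, dtype=torch.bool, device=device)
while not done.all():
    active = ~done
    dist, ctx = estimator.compute_dist(states[active], ctx, step_mask=active)
    actions = dist.sample()
    if save_logprobs:
        _, ctx = estimator.log_probs(
            actions, dist, ctx, step_mask=active, vectorized=False, save_logprobs=True
        )
    states, done = env_step(states, actions, active)  # hypothetical helper stepping the env
```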

## Integration with probability calculators (PF/PB)

Probability utilities in `utils/prob_calculations.py` branch on `is_vectorized` but call the same two methods in both paths:
- `compute_dist(states_active, ctx, step_mask=None or mask)`
- `log_probs(actions_active, dist, ctx, step_mask=None or mask, vectorized=...)`

Key differences:
- Vectorized (fast path)
  - `step_mask=None`, `vectorized=True`.
  - May reuse cached estimator outputs by pre‑setting `ctx.current_estimator_output`.
  - `log_probs` returns raw `dist.log_prob(...)` and does not raise on `-inf`.
- Non‑vectorized (per‑step path)
  - Uses legacy‑accurate masks and alignments:
    - PF (trajectories): `~states.is_sink_state[t] & ~actions.is_dummy[t]`
    - PB (trajectories): aligns action at `t` with state at `t+1`, using `~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] & ~actions.is_dummy[t] & ~actions.is_exit[t]` (skips `t==0`).
    - Transitions: legacy PB mask on `next_states` with `~actions.is_exit`.
  - `log_probs` pads back to `(N,)` and raises if any `±inf` remains after masking.

See `src/gfn/utils/prob_calculations.py` for full branching.
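
A condensed sketch of the branching (the real code additionally handles the PF/PB masks and caching described above; names here are illustrative):

```python
import torch


def evaluate_log_probs(estimator, states, actions, ctx, step_masks):
    if estimator.is_vectorized:
        # Fast path: one call over the flattened batch.
        dist, ctx = estimator.compute_dist(states, ctx, step_mask=None)
        log_probs, ctx = estimator.log_probs(actions, dist, ctx, step_mask=None, vectorized=True)
        return log_probs
    # Per-step path: iterate over time, evaluating only the active entries.
    per_step = []
    for t, mask in enumerate(step_masks):
        dist, ctx = estimator.compute_dist(states[t][mask], ctx, step_mask=mask)
        lp, ctx = estimator.log_probs(actions[t][mask], dist, ctx, step_mask=mask, vectorized=False)
        per_step.append(lp)  # already padded back to (N,)
    return torch.stack(per_step, dim=0)
```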

## Built‑in policy‑capable estimators

- `DiscretePolicyEstimator`: logits → `Categorical` with masking, optional temperature and epsilon‑greedy mixing in log‑space.
- `DiscreteGraphPolicyEstimator`: multi‑head logits (`TensorDict`) → `GraphActionDistribution` with per‑component masks and transforms.
- `RecurrentDiscretePolicyEstimator`: sequence models that maintain a `carry`; requires `init_carry` and returns `(logits, carry)` in `forward`.
- Conditional variants exist for state+conditioning architectures.

## How to write a new policy (or mixin variant)

Most users only need to implement `to_probability_distribution` (or reuse the provided ones). If you need a new interface or extra tracking, you can either:

1) Use `PolicyMixin` (stateless, vectorized) and override `to_probability_distribution` on your estimator.
2) Use `RecurrentPolicyMixin` (per‑step, carry) and implement `init_carry` plus a `forward(states, carry)` that returns `(estimator_outputs, new_carry)`.
3) Create a custom mixin derived from `PolicyMixin` to tailor `compute_dist`/`log_probs` (e.g., custom caching, diagnostics).

### Minimal stateless policy (discrete)

```python
import torch
from torch import nn
from gfn.estimators import DiscretePolicyEstimator

class SmallMLP(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(), nn.Linear(128, output_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# Forward policy over n_actions
policy = DiscretePolicyEstimator(module=SmallMLP(input_dim=32, output_dim=17), n_actions=17)
```

Use with the `Sampler`:

```python
from gfn.samplers import Sampler

sampler = Sampler(policy)
trajectories = sampler.sample_trajectories(env, n=64, save_logprobs=True)
```

### Minimal recurrent policy

```python
import torch
from torch import nn
from gfn.estimators import RecurrentDiscretePolicyEstimator

class TinyRNN(nn.Module):
    def __init__(self, vocab_size: int, hidden: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, hidden)
        self.rnn = nn.GRU(hidden, hidden, batch_first=True)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, tokens: torch.Tensor, carry: dict[str, torch.Tensor]):
        x = self.embed(tokens)
        h0 = carry.get("h", torch.zeros(1, tokens.size(0), x.size(-1), device=tokens.device))
        y, h = self.rnn(x, h0)
        logits = self.head(y)
        return logits, {"h": h}

    def init_carry(self, batch_size: int, device: torch.device) -> dict[str, torch.Tensor]:
        return {"h": torch.zeros(1, batch_size, self.embed.embedding_dim, device=device)}

policy = RecurrentDiscretePolicyEstimator(module=TinyRNN(vocab_size=33, hidden=64), n_actions=33)
```
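
As with the stateless policy, the recurrent estimator plugs into the same `Sampler`; since `is_vectorized=False`, the rollout proceeds step by step (assuming an `env` whose states `TinyRNN` can consume):

```python
from gfn.samplers import Sampler

sampler = Sampler(policy)
trajectories = sampler.sample_trajectories(env, n=64, save_logprobs=True)
```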

### Custom mixin variant (advanced)

If you need to add diagnostics or custom caching, subclass `PolicyMixin` and override `compute_dist`/`log_probs` to interact with `ctx.extras`.

```python
from typing import Any
from torch.distributions import Distribution
from gfn.estimators import PolicyMixin

class TracingPolicyMixin(PolicyMixin):
    def compute_dist(self, states_active, ctx, step_mask=None, save_estimator_outputs=False, **kw):
        dist, ctx = super().compute_dist(states_active, ctx, step_mask, save_estimator_outputs, **kw)
        ctx.extras.setdefault("num_compute_calls", 0)
        ctx.extras["num_compute_calls"] += 1
        return dist, ctx

    def log_probs(self, actions_active, dist: Distribution, ctx: Any, step_mask=None, vectorized=False, save_logprobs=False):
        lp, ctx = super().log_probs(actions_active, dist, ctx, step_mask, vectorized, save_logprobs)
        ctx.extras["last_lp_mean"] = lp.mean().detach()  # record the most recent batch mean
        return lp, ctx
```

Keep `is_vectorized` consistent with your evaluation strategy. If you switch to `False`, ensure your estimator supports per‑step rollouts and masking semantics.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -10,6 +10,7 @@
guides/example
guides/states_actions_containers
guides/modules_estimators_samplers
guides/estimator_adapters
guides/losses
guides/creating_environments
guides/advanced
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ tensordict = ">=0.6.1"
torch = ">=2.6.0"
torch_geometric = ">=2.6.1"
dill = ">=0.3.8"
tokenizers = ">=0.15"

# dev dependencies.
black = { version = "24.3", optional = true }
Empty file added src/gfn/chunking/__init__.py
Empty file.
130 changes: 130 additions & 0 deletions src/gfn/chunking/adapters.py
@@ -0,0 +1,130 @@
from __future__ import annotations

from typing import Any, Optional

import torch
from torch.distributions import Categorical, Distribution

from gfn.chunking.policy import ChunkedPolicy
from gfn.env import DiscreteEnv
from gfn.samplers import AdapterContext, EstimatorAdapter
from gfn.states import DiscreteStates


class ChunkedAdapter(EstimatorAdapter):
    """EstimatorAdapter that produces macro-level distributions using ChunkedPolicy.

    Forward-only in this PR. TODO(backward): support backward chunking by switching
    stepping and termination criteria to the backward direction.
    """

    def __init__(self, env: DiscreteEnv, policy: ChunkedPolicy, library: Any) -> None:
        self.env = env
        self.policy = policy
        self.library = library
        self._is_backward = False  # TODO(backward): allow backward chunking

    @property
    def is_backward(self) -> bool:
        return self._is_backward

    def init_context(
        self,
        batch_size: int,
        device: torch.device,
        conditioning: Optional[torch.Tensor] = None,
    ) -> AdapterContext:
        ctx = AdapterContext(
            batch_size=batch_size, device=device, conditioning=conditioning
        )
        ctx.extras["macro_log_probs"] = []  # List[(N,)]
        return ctx

    def _strict_macro_mask(self, states_active: DiscreteStates) -> torch.Tensor:
        """Strict mask by simulating each macro sequentially on each active state.

        Invalidates a macro if any sub-action is invalid or if sink is reached before
        the sequence completes. Guarantees EXIT macro is valid if no macro is valid.
        """
        B = states_active.batch_shape[0]
        N = self.library.n_actions
        device = states_active.device
        mask = torch.zeros(B, N, dtype=torch.bool, device=device)

        for b in range(B):
Review comment: I wonder if it's possible to parallelize this.

            s_curr = states_active[b : b + 1]
            for j, seq in enumerate(self.library.id_to_sequence):
                ok = True
                s = s_curr
                for k, a in enumerate(seq):
                    a_tensor = self.env.actions_from_tensor(
                        torch.tensor([[a]], device=device)
                    )
                    if not self.env.is_action_valid(s, a_tensor):
                        ok = False
                        break
                    s_next = self.env._step(s, a_tensor)
                    if s_next.is_sink_state.item() and k != len(seq) - 1:
                        ok = False
                        break
                    s = s_next
                mask[b, j] = ok

        # Ensure EXIT macro is available when none is valid
        try:
            exit_id = self.library.id_to_sequence.index([self.env.exit_action.item()])
        except ValueError:
            exit_id = N - 1
        no_valid = ~mask.any(dim=1)
        if no_valid.any():
            mask[no_valid] = False
            mask[no_valid, exit_id] = True
        return mask

    def compute(
        self,
        states_active: DiscreteStates,
        ctx: Any,
        step_mask: torch.Tensor,
        **policy_kwargs: Any,
    ) -> tuple[Distribution, Any]:
        logits = self.policy.forward_logits(states_active)  # (B_active, N)
        macro_mask = self._strict_macro_mask(states_active)
        masked_logits = torch.where(
            macro_mask, logits, torch.full_like(logits, -float("inf"))
        )
        dist = Categorical(logits=masked_logits)
        ctx.current_estimator_output = None
        return dist, ctx

    def record_step(
        self,
        ctx: Any,
        step_mask: torch.Tensor,
        sampled_actions: torch.Tensor,
        dist: Distribution,
        save_logprobs: bool,
        save_estimator_outputs: bool,
    ) -> None:
        if save_logprobs:
            lp_masked = dist.log_prob(sampled_actions)
            step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device)
            step_lp[step_mask] = lp_masked
            ctx.extras["macro_log_probs"].append(step_lp)
        # No estimator outputs for macros by default
        return

    def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]:
        out: dict[str, Optional[torch.Tensor]] = {
            "log_probs": None,
            "estimator_outputs": None,
        }
        macro_log_probs = ctx.extras.get("macro_log_probs", [])
        if macro_log_probs:
            out["macro_log_probs"] = torch.stack(macro_log_probs, dim=0)
        else:
            out["macro_log_probs"] = None
        return out

    def get_current_estimator_output(self, ctx: Any):
        return None
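
For orientation (not part of the diff), a hedged usage sketch of the adapter above, assuming an existing `DiscreteEnv` `env`, a `ChunkedPolicy` `policy`, and a macro `library` with the interface used in this file:

```python
import torch

# Sketch only: `env`, `policy`, and `library` are assumed to be constructed elsewhere.
adapter = ChunkedAdapter(env=env, policy=policy, library=library)
ctx = adapter.init_context(batch_size=16, device=torch.device("cpu"))
states = env.reset(batch_shape=(16,))            # initial DiscreteStates (assumed API)
active = torch.ones(16, dtype=torch.bool)        # all trajectories still active
dist, ctx = adapter.compute(states, ctx, step_mask=active)
macro_actions = dist.sample()                    # macro (chunk) indices, shape (16,)
adapter.record_step(ctx, active, macro_actions, dist, save_logprobs=True, save_estimator_outputs=False)
out = adapter.finalize(ctx)                      # contains "macro_log_probs" when recorded
```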