diff --git a/docs/source/guides/estimator_adapters.md b/docs/source/guides/estimator_adapters.md new file mode 100644 index 00000000..63e34d0c --- /dev/null +++ b/docs/source/guides/estimator_adapters.md @@ -0,0 +1,180 @@ +# PolicyMixin: Policies and Rollouts + +Estimators become policy-capable by mixing in a small, uniform rollout API. This lets the same `Sampler` and probability utilities drive different estimator families (discrete, graph, conditional, recurrent) without bespoke glue code. + +This guide explains: +- The Policy rollout API and `RolloutContext` +- Vectorized vs non‑vectorized probability paths +- How policies integrate with the `Sampler` and probability calculators +- How to implement a new policy mixin or tailor the default behavior + +## Concepts and Goals + +A policy‑capable estimator exposes: +- `is_vectorized: bool` — whether the estimator can be evaluated in a single vectorized call (no per‑step carry). +- `init_context(batch_size, device, conditioning)` — allocate a per‑rollout context. +- `compute_dist(states_active, ctx, step_mask, ...) -> (Distribution, ctx)` — run the model, build a `torch.distributions.Distribution`. +- `log_probs(actions_active, dist, ctx, step_mask, vectorized, ...) -> (Tensor, ctx)` — evaluate log‑probs, optionally padded to batch. +- `get_current_estimator_output(ctx)` — access the last raw model output when requested. + +All per‑step artifacts (e.g., log‑probs, raw outputs, recurrent state) are owned by the `RolloutContext` and recorded by the mixin. + +## RolloutContext + +The `RolloutContext` is a lightweight container created once per rollout: +- `batch_size`, `device`, optional `conditioning` +- Optional `carry` (for recurrent policies) +- Per‑step buffers: `trajectory_log_probs`, `trajectory_estimator_outputs` +- `current_estimator_output` for cached reuse or immediate retrieval +- `extras: dict` for arbitrary policy‑specific data + +See `src/gfn/estimators.py` for the full definition. + +## PolicyMixin (vectorized, default) + +`PolicyMixin` enables vectorized evaluation by default (`is_vectorized=True`). + +- `init_context(batch_size, device, conditioning)` returns a fresh `RolloutContext` with empty buffers. +- `compute_dist(...)`: + - Slices `conditioning` by `step_mask` when provided; uses full `conditioning` when `step_mask=None` (vectorized). + - Optionally reuses `ctx.current_estimator_output` (e.g., PF with cached `trajectories.estimator_outputs`). + - Calls the estimator module and builds a `Distribution` via `to_probability_distribution`. + - When `save_estimator_outputs=True`, sets `ctx.current_estimator_output` and records a padded copy to `ctx.trajectory_estimator_outputs` for non‑vectorized calls. +- `log_probs(...)`: + - `vectorized=True`: returns raw `dist.log_prob(...)` (may include `-inf` for illegal actions) and optionally records to `trajectory_log_probs`. + - `vectorized=False`: strict inf‑check, pads to shape `(N,)` using `step_mask`, records when requested. + +Code reference (log‑probs behavior): `src/gfn/estimators.py`. + +## RecurrentPolicyMixin (per‑step) + +`RecurrentPolicyMixin` sets `is_vectorized=False` and threads a carry through steps: + +- `init_context(...)` requires the estimator to implement `init_carry(batch_size, device)`; stores the result in `ctx.carry`. +- `compute_dist(...)` must call the estimator as `(states_active, ctx.carry) -> (est_out, new_carry)`, update `ctx.carry`, build the `Distribution`, and record outputs when requested (with padding when masked). +- `log_probs(...)` follows the non‑vectorized path (pad and strict checks) and can reuse the same recording semantics as `PolicyMixin`. + +Code reference (carry update and padded recording): `src/gfn/estimators.py`. + +## Integration with the Sampler + +The `Sampler` uses the policy API directly. It creates a single `ctx` per rollout, then repeats `compute_dist` → sample → optional `log_probs` while some trajectories are active. Per‑step artifacts are recorded into `ctx` by the mixin when flags are enabled. + +Excerpt (per‑step call pattern): `src/gfn/samplers.py`. + +## Integration with probability calculators (PF/PB) + +Probability utilities in `utils/prob_calculations.py` branch on `is_vectorized` but call the same two methods in both paths: +- `compute_dist(states_active, ctx, step_mask=None or mask)` +- `log_probs(actions_active, dist, ctx, step_mask=None or mask, vectorized=...)` + +Key differences: +- Vectorized (fast path) + - `step_mask=None`, `vectorized=True`. + - May reuse cached estimator outputs by pre‑setting `ctx.current_estimator_output`. + - `log_probs` returns raw `dist.log_prob(...)` and does not raise on `-inf`. +- Non‑vectorized (per‑step path) + - Uses legacy‑accurate masks and alignments: + - PF (trajectories): `~states.is_sink_state[t] & ~actions.is_dummy[t]` + - PB (trajectories): aligns action at `t` with state at `t+1`, using `~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] & ~actions.is_dummy[t] & ~actions.is_exit[t]` (skips `t==0`). + - Transitions: legacy PB mask on `next_states` with `~actions.is_exit`. + - `log_probs` pads back to `(N,)` and raises if any `±inf` remains after masking. + +See `src/gfn/utils/prob_calculations.py` for full branching. + +## Built‑in policy‑capable estimators + +- `DiscretePolicyEstimator`: logits → `Categorical` with masking, optional temperature and epsilon‑greedy mixing in log‑space. +- `DiscreteGraphPolicyEstimator`: multi‑head logits (`TensorDict`) → `GraphActionDistribution` with per‑component masks and transforms. +- `RecurrentDiscretePolicyEstimator`: sequence models that maintain a `carry`; requires `init_carry` and returns `(logits, carry)` in `forward`. +- Conditional variants exist for state+conditioning architectures. + +## How to write a new policy (or mixin variant) + +Most users only need to implement `to_probability_distribution` (or reuse the provided ones). If you need a new interface or extra tracking, you can either: + +1) Use `PolicyMixin` (stateless, vectorized) and override `to_probability_distribution` on your estimator. +2) Use `RecurrentPolicyMixin` (per‑step, carry) and implement `init_carry` plus a `forward(states, carry)` that returns `(estimator_outputs, new_carry)`. +3) Create a custom mixin derived from `PolicyMixin` to tailor `compute_dist`/`log_probs` (e.g., custom caching, diagnostics). + +### Minimal stateless policy (discrete) + +```python +import torch +from torch import nn +from gfn.estimators import DiscretePolicyEstimator + +class SmallMLP(nn.Module): + def __init__(self, input_dim: int, output_dim: int): + super().__init__() + self.input_dim = input_dim + self.net = nn.Sequential( + nn.Linear(input_dim, 128), nn.ReLU(), nn.Linear(128, output_dim) + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + +# Forward policy over n_actions +policy = DiscretePolicyEstimator(module=SmallMLP(input_dim=32, output_dim=17), n_actions=17) +``` + +Use with the `Sampler`: + +```python +from gfn.samplers import Sampler + +sampler = Sampler(policy) +trajectories = sampler.sample_trajectories(env, n=64, save_logprobs=True) +``` + +### Minimal recurrent policy + +```python +import torch +from torch import nn +from gfn.estimators import RecurrentDiscretePolicyEstimator + +class TinyRNN(nn.Module): + def __init__(self, vocab_size: int, hidden: int): + super().__init__() + self.vocab_size = vocab_size + self.embed = nn.Embedding(vocab_size, hidden) + self.rnn = nn.GRU(hidden, hidden, batch_first=True) + self.head = nn.Linear(hidden, vocab_size) + + def forward(self, tokens: torch.Tensor, carry: dict[str, torch.Tensor]): + x = self.embed(tokens) + h0 = carry.get("h", torch.zeros(1, tokens.size(0), x.size(-1), device=tokens.device)) + y, h = self.rnn(x, h0) + logits = self.head(y) + return logits, {"h": h} + + def init_carry(self, batch_size: int, device: torch.device) -> dict[str, torch.Tensor]: + return {"h": torch.zeros(1, batch_size, self.embed.embedding_dim, device=device)} + +policy = RecurrentDiscretePolicyEstimator(module=TinyRNN(vocab_size=33, hidden=64), n_actions=33) +``` + +### Custom mixin variant (advanced) + +If you need to add diagnostics or custom caching, subclass `PolicyMixin` and override `compute_dist`/`log_probs` to interact with `ctx.extras`. + +```python +from typing import Any, Optional +from torch.distributions import Distribution +from gfn.estimators import PolicyMixin + +class TracingPolicyMixin(PolicyMixin): + def compute_dist(self, states_active, ctx, step_mask=None, save_estimator_outputs=False, **kw): + dist, ctx = super().compute_dist(states_active, ctx, step_mask, save_estimator_outputs, **kw) + ctx.extras.setdefault("num_compute_calls", 0) + ctx.extras["num_compute_calls"] += 1 + return dist, ctx + + def log_probs(self, actions_active, dist: Distribution, ctx: Any, step_mask=None, vectorized=False, save_logprobs=False): + lp, ctx = super().log_probs(actions_active, dist, ctx, step_mask, vectorized, save_logprobs) + ctx.extras.setdefault("last_lp_mean", lp.mean().detach()) + return lp, ctx +``` + +Keep `is_vectorized` consistent with your evaluation strategy. If you switch to `False`, ensure your estimator supports per‑step rollouts and masking semantics. diff --git a/docs/source/index.rst b/docs/source/index.rst index be9ef88c..e41a3104 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -10,6 +10,7 @@ guides/example guides/states_actions_containers guides/modules_estimators_samplers + guides/estimator_adapters guides/losses guides/creating_environments guides/advanced diff --git a/pyproject.toml b/pyproject.toml index f069d335..3cf0d7c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ tensordict = ">=0.6.1" torch = ">=2.6.0" torch_geometric = ">=2.6.1" dill = ">=0.3.8" +tokenizers = ">=0.15" # dev dependencies. black = { version = "24.3", optional = true } diff --git a/src/gfn/chunking/__init__.py b/src/gfn/chunking/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/gfn/chunking/adapters.py b/src/gfn/chunking/adapters.py new file mode 100644 index 00000000..b03e9f23 --- /dev/null +++ b/src/gfn/chunking/adapters.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from typing import Any, Optional + +import torch +from torch.distributions import Categorical, Distribution + +from gfn.chunking.policy import ChunkedPolicy +from gfn.env import DiscreteEnv +from gfn.samplers import AdapterContext, EstimatorAdapter +from gfn.states import DiscreteStates + + +class ChunkedAdapter(EstimatorAdapter): + """EstimatorAdapter that produces macro-level distributions using ChunkedPolicy. + + Forward-only in this PR. TODO(backward): support backward chunking by switching + stepping and termination criteria to the backward direction. + """ + + def __init__(self, env: DiscreteEnv, policy: ChunkedPolicy, library: Any) -> None: + self.env = env + self.policy = policy + self.library = library + self._is_backward = False # TODO(backward): allow backward chunking + + @property + def is_backward(self) -> bool: + return self._is_backward + + def init_context( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> AdapterContext: + ctx = AdapterContext( + batch_size=batch_size, device=device, conditioning=conditioning + ) + ctx.extras["macro_log_probs"] = [] # List[(N,)] + return ctx + + def _strict_macro_mask(self, states_active: DiscreteStates) -> torch.Tensor: + """Strict mask by simulating each macro sequentially on each active state. + + Invalidates a macro if any sub-action is invalid or if sink is reached before + the sequence completes. Guarantees EXIT macro is valid if no macro is valid. + """ + B = states_active.batch_shape[0] + N = self.library.n_actions + device = states_active.device + mask = torch.zeros(B, N, dtype=torch.bool, device=device) + + for b in range(B): + s_curr = states_active[b : b + 1] + for j, seq in enumerate(self.library.id_to_sequence): + ok = True + s = s_curr + for k, a in enumerate(seq): + a_tensor = self.env.actions_from_tensor( + torch.tensor([[a]], device=device) + ) + if not self.env.is_action_valid(s, a_tensor): + ok = False + break + s_next = self.env._step(s, a_tensor) + if s_next.is_sink_state.item() and k != len(seq) - 1: + ok = False + break + s = s_next + mask[b, j] = ok + + # Ensure EXIT macro is available when none is valid + try: + exit_id = self.library.id_to_sequence.index([self.env.exit_action.item()]) + except ValueError: + exit_id = N - 1 + no_valid = ~mask.any(dim=1) + if no_valid.any(): + mask[no_valid] = False + mask[no_valid, exit_id] = True + return mask + + def compute( + self, + states_active: DiscreteStates, + ctx: Any, + step_mask: torch.Tensor, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: + logits = self.policy.forward_logits(states_active) # (B_active, N) + macro_mask = self._strict_macro_mask(states_active) + masked_logits = torch.where( + macro_mask, logits, torch.full_like(logits, -float("inf")) + ) + dist = Categorical(logits=masked_logits) + ctx.current_estimator_output = None + return dist, ctx + + def record_step( + self, + ctx: Any, + step_mask: torch.Tensor, + sampled_actions: torch.Tensor, + dist: Distribution, + save_logprobs: bool, + save_estimator_outputs: bool, + ) -> None: + if save_logprobs: + lp_masked = dist.log_prob(sampled_actions) + step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device) + step_lp[step_mask] = lp_masked + ctx.extras["macro_log_probs"].append(step_lp) + # No estimator outputs for macros by default + return + + def finalize(self, ctx: Any) -> dict[str, Optional[torch.Tensor]]: + out: dict[str, Optional[torch.Tensor]] = { + "log_probs": None, + "estimator_outputs": None, + } + macro_log_probs = ctx.extras.get("macro_log_probs", []) + if macro_log_probs: + out["macro_log_probs"] = torch.stack(macro_log_probs, dim=0) + else: + out["macro_log_probs"] = None + return out + + def get_current_estimator_output(self, ctx: Any): + return None diff --git a/src/gfn/chunking/chunkers.py b/src/gfn/chunking/chunkers.py new file mode 100644 index 00000000..c8a0e362 --- /dev/null +++ b/src/gfn/chunking/chunkers.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import random +from abc import ABC, abstractmethod +from collections import Counter +from typing import TYPE_CHECKING, Any, Hashable, Sequence + +from tokenizers import Tokenizer +from tokenizers.models import BPE, WordPiece +from tokenizers.trainers import BpeTrainer, WordPieceTrainer + +if TYPE_CHECKING: # Avoid runtime import to break circular deps with env/containers + from gfn.containers.trajectories import Trajectories + + +class Chunker(ABC): + """Abstract base class for chunkers that propose new vocab tokens. + + Chunkers operate on trajectories and environment context and return + a sequence of token keys (any Hashable) to be added to the env vocab. + """ + + @abstractmethod + def propose_tokens( + self, + env: "Any", + trajectories: Trajectories, + n_tokens_to_add: int, + remove_old: bool, + ) -> Sequence[Hashable]: + raise NotImplementedError + + +class UniformChunker(Chunker): + """Proposes random bigrams of current non-exit tokens as tuples of ints.""" + + def propose_tokens( + self, + env: "Any", + trajectories: Trajectories, + n_tokens_to_add: int, + remove_old: bool, + ) -> Sequence[Hashable]: + # Build non-exit pool from current vocab ids. + non_exit_ids = [i for i in range(env.n_actions) if i != env.exit_token_id] + seen = set(env.vocab) + out: set[Hashable] = set() + while len(out) < n_tokens_to_add and len(out) < 10_000: + a, b = random.choice(non_exit_ids), random.choice(non_exit_ids) + candidate = (a, b) + if candidate not in seen: + out.add(candidate) + return list(out) + + +class _StringMapping: + """Utility to map env keys to strings suitable for tokenizers.""" + + def __init__(self, delimiter: str = "") -> None: + self.delimiter = delimiter + + def key_to_str(self, key: Hashable) -> str: + if isinstance(key, tuple): + return self.delimiter.join(str(x) for x in key) + return str(key) + + +class BPEChunker(Chunker): + def __init__(self, unk_token: str = "[UNK]", delimiter: str = "") -> None: + self.unk_token = unk_token + self.mapper = _StringMapping(delimiter=delimiter) + + def propose_tokens( + self, + env: "Any", + trajectories: Trajectories, + n_tokens_to_add: int, + remove_old: bool, + min_frequency: int = 5, + ) -> Sequence[Hashable]: + # Build corpus strings from trajectories via env tokenizer + corpus = env.trajectories_to_token_strings(trajectories) + + # Build initial vocab from current env keys mapped to strings + vocab_dict = {self.mapper.key_to_str(k): i for i, k in enumerate(env.vocab)} + tokenizer = Tokenizer(BPE(vocab_dict, [], unk_token=self.unk_token)) + + target_vocab_size = len(env.vocab) - 1 + n_tokens_to_add + trainer = BpeTrainer( + vocab_size=target_vocab_size, # type: ignore + special_tokens=[self.unk_token], # type: ignore + min_frequency=min_frequency, # type: ignore + ) + tokenizer.train_from_iterator(corpus, trainer=trainer) + + # Take the most common new tokens. + base_vocab = set(vocab_dict.keys()) + encodings = tokenizer.encode_batch(corpus) + counts = Counter() + for enc in encodings: + for tok in enc.tokens: + if tok not in base_vocab and tok != self.unk_token and len(tok) > 0: + counts[tok] += 1 + + top_new = [tok for tok, _ in counts.most_common(n_tokens_to_add)] + return top_new + + +class WordPieceChunker(Chunker): + def __init__(self, unk_token: str = "[UNK]", delimiter: str = "") -> None: + self.unk_token = unk_token + self.mapper = _StringMapping(delimiter=delimiter) + + def propose_tokens( + self, + env: "Any", + trajectories: Trajectories, + n_tokens_to_add: int, + remove_old: bool, + min_frequency: int = 5, + ) -> Sequence[Hashable]: + corpus = env.trajectories_to_token_strings(trajectories) + vocab_dict = {self.mapper.key_to_str(k): i for i, k in enumerate(env.vocab)} + tokenizer = Tokenizer( + WordPiece( + vocab=vocab_dict, + unk_token=self.unk_token, + max_input_chars_per_word=100, + ) + ) + target_vocab_size = len(env.vocab) - 1 + n_tokens_to_add + trainer = WordPieceTrainer( + vocab_size=target_vocab_size, + continuing_subword_prefix="##", # Defined prefix (removed later). + special_tokens=[self.unk_token], + min_frequency=min_frequency, + ) + tokenizer.train_from_iterator(corpus, trainer=trainer) + + # Take the most common new tokens. + base_vocab = set(vocab_dict.keys()) + encodings = tokenizer.encode_batch(corpus) + counts = Counter() + for enc in encodings: + for tok in enc.tokens: + if tok not in base_vocab and tok != self.unk_token and len(tok) > 0: + counts[tok.lstrip("##")] += 1 # Remove prefix if present. + + top_new = [tok for tok, _ in counts.most_common(n_tokens_to_add)] + return top_new diff --git a/src/gfn/chunking/policy.py b/src/gfn/chunking/policy.py new file mode 100644 index 00000000..8689d15b --- /dev/null +++ b/src/gfn/chunking/policy.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List + +import torch +from torch import nn + +from gfn.states import DiscreteStates + +if TYPE_CHECKING: + from gfn.env import ChunkedDiscreteEnvironment + + +class ChunkedPolicy(nn.Module): + """Compute logits over a macro library via state and macro embeddings. + + The `state_module` maps preprocessed states to a fixed-size embedding. The + `action_encoder` maps action sequences (macros) to the same embedding space. + Logits are the scaled dot products between state embeddings and macro embeddings. + """ + + def __init__( + self, + state_module: nn.Module, + action_encoder: ActionEncoder, + env: "ChunkedDiscreteEnvironment", + action_embedding_dim: int, + primitive_id_mapper: Callable[[int], int] | None = None, + ) -> None: + super().__init__() + self.state_module = state_module + self.action_encoder = action_encoder + self.env = env + self.action_embedding_dim = int(action_embedding_dim) + self.primitive_id_mapper: Callable[[int], int] = ( + primitive_id_mapper if primitive_id_mapper is not None else (lambda x: x) + ) + self.register_buffer("_library_embeddings", torch.empty(0)) + + @torch.no_grad() + def refresh_library_embeddings(self, device: torch.device) -> None: + # Build sequences from env vocab decoded to primitive ids + vocab = self.env.vocab + seqs: List[List[int]] = [] + max_len = 0 + + # TODO: This should rely on the env tokenizer instead of primitive_id_mapper. + for key in vocab: + decoded = list(self.env.decode_key_to_actions(key)) + mapped = [self.primitive_id_mapper(i) for i in decoded] + max_len = max(max_len, len(mapped)) + + # Ensure at least length 1 for encoder stability. + seqs.append(mapped if len(mapped) > 0 else [0]) # 0 is a placeholder FIXME. + + if len(seqs) == 0: + self._library_embeddings = torch.empty( + 0, self.action_embedding_dim, device=device + ) + return + + x = torch.full( + (len(seqs), max_len), + fill_value=0, + dtype=torch.long, + device=device, + ) + + # TODO: Vectorize this. + for i, s in enumerate(seqs): + if len(s) > 0: + x[i, : len(s)] = torch.tensor(s, dtype=torch.long, device=device) + + self._library_embeddings = self.action_encoder(x) # (N, D) + + def forward_logits(self, states: DiscreteStates) -> torch.Tensor: + state_emb = self.state_module(states) # (*B, D) + if ( + self._library_embeddings.numel() == 0 + or self._library_embeddings.shape[0] != self.env.n_actions + ): + self.refresh_library_embeddings(device=state_emb.device) + + logits = torch.einsum("bd,nd->bn", state_emb, self._library_embeddings) + logits = logits / (state_emb.shape[-1] ** 0.5) + + return logits + + +class ActionModel(nn.Module): + def __init__( + self, + n_primitive_actions: int, + hidden_dim: int = 256, + action_embedding_dimension: int = 128, + ) -> None: + super().__init__() + self.primitive_embedding = nn.Embedding(n_primitive_actions, hidden_dim) + self.rnn_encoder = nn.GRU(hidden_dim, hidden_dim, num_layers=2, batch_first=True) + self.out_layer = nn.Sequential( + nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, action_embedding_dimension) + ) + + def forward(self, x): # x: (B, L) + emb = self.primitive_embedding(x) + s, _ = self.rnn_encoder(emb) + out = s[:, -1] + out = self.out_layer(out) + return out + + +class PositionalEncoding(nn.Module): + """Minimal sinusoidal positional encoding for completeness. + + If a richer implementation exists elsewhere, prefer importing it. + """ + + pe: torch.Tensor # registered buffer + + def __init__(self, dim: int, dropout: float = 0.0, max_len: int = 512) -> None: + super().__init__() + self.dropout = nn.Dropout(dropout) + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len, dtype=torch.get_default_dtype()).unsqueeze(1) + div_term = torch.exp( + torch.arange( + 0, + dim, + 2, + dtype=torch.get_default_dtype(), + ) + * (-torch.log(torch.tensor(10000.0)) / dim) + ) + + pe[:, 0::2] = torch.sin(position * div_term) + if dim % 2 == 0: + pe[:, 1::2] = torch.cos(position * div_term) + else: + pe[:, 1::2] = torch.cos(position * div_term)[:, : dim // 2] + + self.register_buffer("pe", pe) + + def forward(self, x): + # x: (B, L, D) + L = x.size(1) + x = x + self.pe[:L] + return self.dropout(x) + + +class ActionEncoder(nn.Module): + def __init__( + self, + n_primitive_actions: int, + action_embedding_dimension: int, + hidden_dim: int, + num_layers: int, + num_head: int, + max_len: int = 60, + dropout: float = 0.0, + ) -> None: + super().__init__() + self.pos = PositionalEncoding(hidden_dim, dropout=dropout, max_len=max_len + 1) + self.embedding = nn.Embedding(n_primitive_actions, hidden_dim) + encoder_layers = nn.TransformerEncoderLayer( + hidden_dim, num_head, hidden_dim, dropout=dropout, batch_first=True + ) + self.encoder = nn.TransformerEncoder(encoder_layers, num_layers) + + # TODO: For the action encoder to work properly with macros of variable length, + # do we need the embedding layer to be recurrent. Or can we just use a simple + # embedding layer? + self.action_embedding_layer = nn.Linear(hidden_dim, action_embedding_dimension) + + def forward(self, x_ids): # (B, L) with 0 = PAD + pad = x_ids == 0 # (B, L) bool + x = self.embedding(x_ids) # (B, L, D) + x = self.pos(x) # (B, L, D) + x = self.encoder(x, src_key_padding_mask=pad) # mask pads in attention + + mask = (~pad).unsqueeze(-1) # (B, L, 1) + denom = mask.sum(dim=1).clamp_min(1) # (B, 1) + pooled = (x * mask).sum(dim=1) / denom # (B, D) + + return self.action_embedding_layer(pooled) diff --git a/src/gfn/env.py b/src/gfn/env.py index e4296c16..e17be5f5 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod -from typing import Optional, Tuple, cast +from contextlib import contextmanager +from typing import Any, Callable, Hashable, Optional, Sequence, Tuple, cast import torch from torch_geometric.data import Data as GeometricData from gfn.actions import Actions, GraphActions -from gfn.states import DiscreteStates, GraphStates, States +from gfn.chunking.chunkers import Chunker +from gfn.states import ChunkedStates, DiscreteStates, GraphStates, States from gfn.utils.common import default_fill_value_for_dtype, ensure_same_device, set_seed # Errors @@ -328,6 +330,9 @@ def _step(self, states: States, actions: Actions) -> States: # For the indices where the new states are not sink states (i.e., where the # state is not already a sink and the action is not exit), update those # positions with the result of the environment's step function. + # TODO: Ensure that the step function returns a States instance with the same + # type as the States class. Right now, we initialize with -inf, which assumes + # a float type. For now, I will handle casting outside of this method. new_states = self.States.make_sink_states( states.batch_shape, device=states.device ) @@ -659,6 +664,21 @@ def _step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: """ new_states = super()._step(states, actions) new_states = cast(DiscreteStates, new_states) + + # Ensure dtypes of all tensor fields match the input states. + # TODO: We should probably fix this at the base Env class level, but + # I want to do it in a follow up PR. + if new_states.tensor.dtype != states.tensor.dtype: + new_states.tensor = new_states.tensor.to(dtype=states.tensor.dtype) + if new_states.forward_masks.dtype != states.forward_masks.dtype: + new_states.forward_masks = new_states.forward_masks.to( + dtype=states.forward_masks.dtype + ) + if new_states.backward_masks.dtype != states.backward_masks.dtype: + new_states.backward_masks = new_states.backward_masks.to( + dtype=states.backward_masks.dtype + ) + self.update_masks(new_states) return new_states @@ -676,6 +696,21 @@ def _backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteSt """ new_states = super()._backward_step(states, actions) new_states = cast(DiscreteStates, new_states) + + # Ensure dtypes of all tensor fields match the input states. + # TODO: We should probably fix this at the base Env class level, but + # I want to do it in a follow up PR. + if new_states.tensor.dtype != states.tensor.dtype: + new_states.tensor = new_states.tensor.to(dtype=states.tensor.dtype) + if new_states.forward_masks.dtype != states.forward_masks.dtype: + new_states.forward_masks = new_states.forward_masks.to( + dtype=states.forward_masks.dtype + ) + if new_states.backward_masks.dtype != states.backward_masks.dtype: + new_states.backward_masks = new_states.backward_masks.to( + dtype=states.backward_masks.dtype + ) + self.update_masks(new_states) return new_states @@ -760,6 +795,417 @@ def terminating_states(self) -> DiscreteStates: ) +class ChunkedDiscreteEnvironment(DiscreteEnv): + """Discrete environment with chunking-aware action vocab management. + + Intended behavior: + - Exit invariance: the exit action is represented by a fixed, non-executable + sentinel key at its index and never moves when the vocab grows. + - Vocab growth: new primitive/macro tokens are appended (no reindexing of + existing actions). Token-to-id and id-to-token mappings are maintained by + the environment. + - Self-healing states: the custom ``States`` class returned by + :meth:`make_states_class` automatically resizes forward/backward masks to + ``env.n_actions`` and overlays environment-wide constraints (soft-disabled + actions and strict macro feasibility) whenever masks are created or the + batch shape changes. + - Chunker integration: arbitrary chunkers can propose new tokens; helpers + are provided to build string corpora from trajectories for text-based + chunkers (e.g., BPE/WordPiece). + """ + + def __init__( + self, + n_actions: int, + s0: torch.Tensor, + state_shape: Tuple | int, + *, + action_shape: Tuple | int = (1,), + dummy_action: Optional[torch.Tensor] = None, + exit_action: Optional[torch.Tensor] = None, + sf: Optional[torch.Tensor] = None, + check_action_validity: bool = True, + tokenizer: Optional[Callable[[Sequence[int]], str]] = None, + detokenizer: Optional[Callable[[str], Sequence[int]]] = None, + ) -> None: + # Delegate exit action handling to DiscreteEnv (base behavior). + super().__init__( + n_actions=n_actions, + s0=s0, + state_shape=state_shape, + action_shape=action_shape, + dummy_action=dummy_action, + exit_action=exit_action, + sf=sf, + check_action_validity=check_action_validity, + ) + + # Fixed exit action id derived from parent init. + self.exit_token_id: int = int(self.exit_action.item()) + + # Re-entrancy guard for macro overlay recursion + self._macro_overlay_depth: int = 0 + + # Hashable-keyed vocab: start with primitive ids, then exit sentinel at its index. + self._exit_key: str = "" + self.id_to_token_key: list[Hashable] = list(range(self.n_actions - 1)) + [ + self._exit_key + ] + self.token_key_to_id: dict[Hashable, int] = { + k: i for i, k in enumerate(self.id_to_token_key) + } + + # Initially, all non-exit tokens are considered atomic. + self._atomic_token_ids: set[int] = { + i for i in range(self.n_actions) if i != self.exit_token_id + } + self._disabled_token_ids: set[int] = set() + + self._tokenizer: Callable[[Sequence[int]], str] + self._tokenizer = tokenizer if tokenizer is not None else self._default_tokenizer + + # Optional detokenizer to convert string keys to primitive action sequences. + # TODO: the tokenizer should have methods for this (instead of a unique detokenizer). + self._detokenizer: Optional[Callable[[str], Sequence[int]]] = detokenizer + + @staticmethod + def _default_tokenizer(ids: Sequence[int]) -> str: + # Example: [1, 30, 7] -> "1,30,7," + return ",".join(str(i) for i in ids) + "," + + @property + def vocab(self) -> list[Hashable]: + return list(self.id_to_token_key) + + def add_tokens(self, new_keys: Sequence[Hashable]) -> list[int]: + """Append new token keys to the vocab, assigning fresh ids. + + Args: + new_keys: Iterable of integer token keys to add. + + Returns: + The list of ids assigned to the newly added keys (in insertion order). + """ + new_ids: list[int] = [] + for key in new_keys: + if key in self.token_key_to_id: + continue + + new_id = len(self.id_to_token_key) + self.id_to_token_key.append(key) + self.token_key_to_id[key] = new_id + new_ids.append(new_id) + + if new_ids: + self.n_actions = len(self.id_to_token_key) + + return new_ids + + def disable_tokens(self, tokens_or_ids: Sequence[int]) -> None: + for tid in tokens_or_ids: + tid_i = int(tid) + if tid_i != self.exit_token_id: + self._disabled_token_ids.add(tid_i) + + def enable_tokens(self, tokens_or_ids: Sequence[int]) -> None: + for tid in tokens_or_ids: + tid_i = int(tid) + self._disabled_token_ids.discard(tid_i) + + def chunk_and_update_vocab( + self, + trajectories: Any, + chunker: Chunker, + n_tokens_to_add: int, + remove_old: bool = False, + ) -> list[int]: + """Calls a user-provided chunker and updates the vocab with proposed keys. + + The chunker returns token keys (typically integers) to be appended. If + remove_old is True, previously learned non-atomic tokens that were not + just added will be soft-disabled via masks. + """ + # Expect a Chunker-like instance exposing propose_tokens(env, trajectories, n, remove_old) + proposed_keys = list( + chunker.propose_tokens(self, trajectories, n_tokens_to_add, remove_old) + ) + new_ids = self.add_tokens(proposed_keys) + if remove_old: + keep = set(self._atomic_token_ids) | {self.exit_token_id} | set(new_ids) + to_disable = [i for i in range(self.n_actions) if i not in keep] + self.disable_tokens(to_disable) + return new_ids + + def trajectories_to_action_sequences(self, trajs: Any) -> list[list[int]]: + # actions: (T, B) after squeeze, terminating_idx: (B,) + actions = trajs.actions.tensor.squeeze(-1) + term = trajs.terminating_idx + out: list[list[int]] = [] + T, B = actions.shape + for i in range(B): + L = int(term[i].item()) + idxs = [ + int(a) for a in actions[:L, i].tolist() if int(a) != self.exit_token_id + ] + out.append(idxs) + return out + + def trajectories_to_token_strings(self, trajs: Any) -> list[str]: + seqs = self.trajectories_to_action_sequences(trajs) + return [self._tokenizer(seq) for seq in seqs] + + def apply_soft_disabled_to_forward_masks(self, states: DiscreteStates) -> None: + if self._disabled_token_ids: + ids = torch.tensor(sorted(self._disabled_token_ids), device=states.device) + states.forward_masks[..., ids] = False + + @contextmanager + def macro_mask_guard(self): + """Temporarily disable macro overlay application to avoid recursion. + + Increments an environment-local depth counter; while non-zero, calls to + apply_macro_forward_mask will no-op. Always decremented in a finally block. + """ + self._macro_overlay_depth += 1 + try: + yield + finally: + self._macro_overlay_depth -= 1 + assert self._macro_overlay_depth >= 0 + + def make_states_class(self) -> type[DiscreteStates]: + """Returns the DiscreteStates class for this environment. + + Returns: + A type of a subclass of DiscreteStates with environment-specific + functionalities. + """ + env = self + + class ChunkedEnvStates(ChunkedStates): + """States for chunked env that auto-resize and overlay masks. + + Responsibilities: + - Keep ``n_actions`` synchronized with the parent environment. + - Lazily resize masks to match ``env.n_actions`` after construction, + extension, or padding. + - Apply environment overlays (soft-disables and strict macro feasibility) + after any (re)allocation of masks. + """ + + state_shape = env.state_shape + s0 = env.s0 + sf = env.sf + make_random_states = env.make_random_states + n_actions = env.n_actions + device = env.device + + # wire hooks into the shared ChunkedStates base + @staticmethod + def get_n_actions() -> int: + return env.n_actions + + @staticmethod + def overlay_masks(s: "ChunkedStates") -> None: + env.apply_soft_disabled_to_forward_masks(s) + env.apply_macro_forward_mask(s) + + return ChunkedEnvStates + + @abstractmethod + def update_masks(self, states: DiscreteStates) -> None: + """Subclasses must compute env-specific masks and then call overlay helper. + + Example pattern: + states.set_nonexit_action_masks(cond=..., allow_exit=True) + self.apply_soft_disabled_to_forward_masks(states) + """ + ... + + def decode_key_to_actions(self, key: Hashable) -> Sequence[int]: + """Decodes a vocab key (potential macro) into a sequence of primitive actions. + + - int -> [int] + - tuple[int,...] -> list(tuple) + - str -> detokenizer(str) if provided; otherwise empty sequence (non-executable) + """ + if isinstance(key, int): + return [int(key)] + if isinstance(key, tuple): + # assume a tuple of ints + return [int(x) for x in key] + if isinstance(key, str): + if key == self._exit_key: + return [] + if self._detokenizer is not None: + return list(self._detokenizer(key)) + + raise ValueError(f"Invalid key: {key}") + + def _decode_action_id_to_sequence(self, action_id: int) -> Sequence[int]: + key = self.id_to_token_key[action_id] + return self.decode_key_to_actions(key) + + def is_macro_id(self, action_id: int) -> bool: + seq = self._decode_action_id_to_sequence(action_id) + return len(seq) > 1 + + def _compute_macro_mask_flat(self, states: DiscreteStates) -> torch.Tensor: + """Compute macro feasibility for a 1D batch of ChunkedStates. + + Returns: (B, n_actions) boolean mask for macros only (primitives/exit left True). + """ + if not isinstance(states, ChunkedStates): + raise TypeError("compute macro mask requires ChunkedStates") + + assert len(states.batch_shape) == 1 + B = states.batch_shape[0] + macro_mask = torch.ones( + B, self.n_actions, dtype=torch.bool, device=states.device + ) + + # Collect macro sequences + macro_sequences: dict[int, Sequence[int]] = {} + for action_id in range(self.n_actions): + seq = self._decode_action_id_to_sequence(action_id) + if len(seq) > 1: + macro_sequences[action_id] = seq + + if not macro_sequences: + return macro_mask # no macros to validate. + + for aid, seq in macro_sequences.items(): + # Local working copy of states; do not mutate caller's states + s_curr = states.clone() + valid_vec = torch.ones(B, dtype=torch.bool, device=states.device) + + for primitive_id in seq: + # Per-state validity for this primitive at the current step. + step_valid = s_curr.forward_masks[:, primitive_id] + to_step = valid_vec & step_valid + valid_vec &= step_valid # Update cumulative validity + + if bool(to_step.any().item()): + idx = torch.where(to_step)[0] + n = idx.numel() + a_tensor = torch.full( + (n, *self.action_shape), + primitive_id, + device=states.device, + dtype=torch.long, + ) + a = self.actions_from_tensor(a_tensor) + next_sub = super()._step(s_curr[idx], a) + s_curr[idx] = next_sub + + # Refresh masks after stepping (guard prevents recursion) + self.update_masks(s_curr) + + macro_mask[:, aid] = valid_vec + + return macro_mask + + def compute_strict_macro_forward_mask(self, states: DiscreteStates) -> torch.Tensor: + """Returns a mask of shape (batch_shape*, n_actions) validating macros. + + Supports (B) and (T,B). Requires ChunkedStates. + """ + # Enforce ChunkedStates + if not isinstance(states, ChunkedStates): + raise TypeError("compute_strict_macro_forward_mask requires ChunkedStates") + + with self.macro_mask_guard(): + if len(states.batch_shape) == 1: + return self._compute_macro_mask_flat(states) + elif len(states.batch_shape) == 2: + T, B = states.batch_shape + # Horizon pre-check: macros longer than remaining steps are invalid + macro_lengths = torch.tensor( + [ + len(self._decode_action_id_to_sequence(i)) + for i in range(self.n_actions) + ], + device=states.device, + dtype=torch.long, + ) + t_idx = torch.arange(T, device=states.device) + remaining = (T - t_idx).view(T, 1) + horizon_ok = ( + macro_lengths.view(1, self.n_actions) <= remaining.view(T, 1) + ).to(torch.bool) + + # Compute feasibility ignoring horizon, then AND with horizon_ok + flat = states.flatten() + flat_mask = self._compute_macro_mask_flat(flat) # (T*B, n_actions) + tb_mask = flat_mask.view(T, B, self.n_actions) + # Broadcast horizon_ok over B + tb_mask = tb_mask & horizon_ok.view(T, 1, self.n_actions) + return tb_mask + else: + raise ValueError( + f"Expected batch_shape (B) or (T,B), got {states.batch_shape}" + ) + + def apply_macro_forward_mask(self, states: DiscreteStates) -> None: + # Skip macro overlay while inside macro feasibility computation + if getattr(self, "_macro_overlay_depth", 0) > 0: + return + macro_mask = self.compute_strict_macro_forward_mask(states) + states.forward_masks = states.forward_masks & macro_mask + + def _step(self, states: ChunkedStates, actions: Actions) -> ChunkedStates: + """Overrides base to unroll macro actions sequentially. + + Non-macro actions are delegated to the base implementation. + """ + assert states.batch_shape == actions.batch_shape + B = states.batch_shape[0] + + # Identify macro vs non-macro per batch element + action_ids = actions.tensor.view(B) + macro_flags = torch.zeros(B, dtype=torch.bool, device=states.device) + for i in range(B): + macro_flags[i] = self.is_macro_id(int(action_ids[i].item())) + + # Fast path: no macros found. + if not bool(macro_flags.any().item()): + return cast(ChunkedStates, super()._step(states, actions)) + + # Split states/actions + out_states = self.States.make_sink_states( + states.batch_shape, device=states.device + ) + + # Handle non-macro subset via base + non_macro_idx = ~macro_flags + if bool(non_macro_idx.any().item()): + nm_states = states[non_macro_idx] + nm_actions = actions[non_macro_idx] + nm_next = super()._step(nm_states, nm_actions) + out_states[non_macro_idx] = nm_next + + # Handle macros by sequential unroll + if bool(macro_flags.any().item()): + m_states = states[macro_flags] + m_action_ids = action_ids[macro_flags] + # iterate each macro in the smaller batch + curr = m_states + for j in range(curr.batch_shape[0]): + aid = int(m_action_ids[j].item()) + seq = self._decode_action_id_to_sequence(aid) + s = curr[j : j + 1] + for primitive_id in seq: + a_tensor = self.actions_from_tensor( + torch.tensor([[primitive_id]], device=states.device) + ) + s = super()._step(s, a_tensor) + out_states[macro_flags][j : j + 1] = s + + # Update masks for the resulting batch + self.update_masks(cast(DiscreteStates, out_states)) + return cast(ChunkedStates, out_states) + + class GraphEnv(Env): """Base class for graph-based environments. diff --git a/src/gfn/estimators.py b/src/gfn/estimators.py index 24f7c158..f30deac4 100644 --- a/src/gfn/estimators.py +++ b/src/gfn/estimators.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Any, Optional +from typing import Any, Callable, Dict, List, Optional, Protocol, cast, runtime_checkable import torch import torch.nn as nn @@ -11,6 +11,10 @@ from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, States from gfn.utils.distributions import GraphActionDistribution, UnsqueezedCategorical +from gfn.utils.handlers import ( + has_conditioning_exception_handler, + no_conditioning_exception_handler, +) REDUCTION_FUNCTIONS = { "mean": torch.mean, @@ -19,6 +23,285 @@ } +class RolloutContext: + """Structured per‑rollout state owned by estimators. + + Holds rollout invariants and optional per‑step buffers; use ``extras`` for + estimator‑specific fields without changing the class shape. + """ + + __slots__ = ( + "batch_size", + "device", + "conditioning", + "carry", + "trajectory_log_probs", + "trajectory_estimator_outputs", + "current_estimator_output", + "extras", + ) + + def __init__( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> None: + self.batch_size = batch_size + self.device = device + self.conditioning = conditioning + self.carry = None + self.trajectory_log_probs: List[torch.Tensor] = [] + self.trajectory_estimator_outputs: List[torch.Tensor] = [] + self.current_estimator_output: Optional[torch.Tensor] = None + self.extras: Dict[str, Any] = {} + + +@runtime_checkable +class PolicyEstimatorProtocol(Protocol): + """Static-typing surface for estimators that are policy-capable. + + This protocol captures the methods provided by the PolicyMixin so that external + code (e.g., samplers/probability calculators) can use a precise type rather than + relying on dynamic attributes. This helps static analyzers avoid false positives + like "Tensor is not callable" when calling mixin methods. + """ + + is_vectorized: bool + + def init_context( # noqa: E704 + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> Any: ... + + def compute_dist( # noqa: E704 + self, + states_active: States, + ctx: Any, + step_mask: Optional[torch.Tensor] = None, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: ... + + def log_probs( # noqa: E704 + self, + actions_active: torch.Tensor, + dist: Distribution, + ctx: Any, + step_mask: Optional[torch.Tensor] = None, + vectorized: bool = False, + **kwargs: Any, + ) -> tuple[torch.Tensor, Any]: ... + + +class PolicyMixin: + """Mixin enabling an `Estimator` to act as a policy (distribution over actions). + + Provides the generic rollout API (`init_context`, `compute_dist`, `log_probs`) + directly on the estimator. Standard policies should inherit from this mixin. + """ + + @property + def is_vectorized(self) -> bool: + """Used for vectorized probability calculations.""" + return True + + def init_context( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> RolloutContext: + """Create a new per-rollout context. + + Stores rollout invariants (batch size, device, optional conditioning) and + initializes empty buffers for per-step artifacts. + + """ + return RolloutContext( + batch_size=batch_size, device=device, conditioning=conditioning + ) + + def compute_dist( + self, + states_active: States, + ctx: Any, + step_mask: Optional[torch.Tensor] = None, + save_estimator_outputs: bool = False, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: + """Run the estimator for active rows and build an action Distribution. + + Args: + states_active: The states to run the estimator on. + ctx: The context to run the estimator on. + step_mask: The mask to slice the conditioning to the active subset. + save_estimator_outputs: Whether to save the estimator outputs. + **policy_kwargs: Additional keyword arguments to pass to the estimator. + + Returns: + A tuple containing the distribution and the context. + + - Uses `step_mask` to slice conditioning to the active subset. When `step_mask` + is None, the estimator running in a vectorized context. + - Saves the raw estimator output in `ctx.current_estimator_output` for + optional recording in `record_step`. + """ + precopmputed_estimator_outputs = getattr(ctx, "current_estimator_output", None) + + if step_mask is None and precopmputed_estimator_outputs is not None: + expected_bs = states_active.batch_shape[0] + if precopmputed_estimator_outputs.shape[0] != expected_bs: + raise RuntimeError( + "current_estimator_output batch size does not match active states. " + f"Got {precopmputed_estimator_outputs.shape[0]}, expected {expected_bs}. " + "This indicates stale cache reuse; ensure per-step masking when setting " + "ctx.current_estimator_output and clear it when not valid." + ) + estimator_outputs = precopmputed_estimator_outputs + + # Otherwise, compute the estimator outputs. + else: + cond_active = None + if ctx.conditioning is not None: + if step_mask is None: + cond_active = ctx.conditioning + else: + cond_active = ctx.conditioning[step_mask] + + # Call estimator with or without conditioning (ensures preprocessor is applied). + if cond_active is not None: + with has_conditioning_exception_handler("estimator", self): + estimator_outputs = self(states_active, cond_active) # type: ignore[misc,call-arg] + else: + with no_conditioning_exception_handler("estimator", self): + estimator_outputs = self(states_active) # type: ignore[misc] + + # Build the distribution. + dist = self.to_probability_distribution( + states_active, estimator_outputs, **policy_kwargs + ) + + # Save current estimator output only when requested. + if save_estimator_outputs: + ctx.current_estimator_output = estimator_outputs + + # If we are in a non-vectorized path (masked), append a padded copy to trajectory. + if step_mask is not None: + padded = torch.full( + (ctx.batch_size,) + estimator_outputs.shape[1:], + -float("inf"), + device=ctx.device, + ) + padded[step_mask] = estimator_outputs + ctx.trajectory_estimator_outputs.append(padded) + + else: + ctx.current_estimator_output = None + + return dist, ctx + + def log_probs( + self, + actions_active: torch.Tensor, + dist: Distribution, + ctx: Any, + step_mask: Optional[torch.Tensor] = None, + vectorized: bool = False, + save_logprobs: bool = False, + ) -> tuple[torch.Tensor, Any]: + """Compute log-probs, optionally padding back to full batch when non-vectorized.""" + lp = dist.log_prob(actions_active) + + if vectorized: + if save_logprobs: + ctx.trajectory_log_probs.append(lp) + return lp, ctx + + # Non-vectorized path strict check. None of these should be -inf after masking. + if torch.any(torch.isinf(lp)): + raise RuntimeError("Log probabilities are inf. This should not happen.") + + assert step_mask is not None, "step_mask is required when vectorized=False" + step_lp = torch.full((ctx.batch_size,), 0.0, device=ctx.device, dtype=lp.dtype) + step_lp[step_mask] = lp + + if save_logprobs: + ctx.trajectory_log_probs.append(step_lp) + + return step_lp, ctx + + def get_current_estimator_output(self, ctx: Any) -> Optional[torch.Tensor]: + """Expose the most recent per-step estimator output saved during `compute`.""" + return getattr(ctx, "current_estimator_output", None) + + +class RecurrentPolicyMixin(PolicyMixin): + """Mixin for recurrent policies that maintain and update a rollout carry.""" + + @property + def is_vectorized(self) -> bool: + return False + + def init_context( + self, + batch_size: int, + device: torch.device, + conditioning: Optional[torch.Tensor] = None, + ) -> RolloutContext: + ctx = super().init_context(batch_size, device, conditioning) + init_carry = getattr(self, "init_carry", None) + if not callable(init_carry): + raise TypeError( + "Recurrent policy requires init_carry(batch_size: int, device: torch.device)." + ) + init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) + ctx.carry = init_carry_fn(batch_size, device) + + return ctx + + def compute_dist( + self, + states_active: States, + ctx: Any, + step_mask: Optional[torch.Tensor] = None, + save_estimator_outputs: bool = False, + **policy_kwargs: Any, + ) -> tuple[Distribution, Any]: + """Run estimator with carry and update it. + + Differs from the default PolicyMixin by calling + `estimator(states_active, ctx.carry) -> (est_out, new_carry)`, storing the + updated carry and saving `current_estimator_output` before building the + Distribution. + """ + estimator_outputs, new_carry = self(states_active, ctx.carry) # type: ignore + ctx.carry = new_carry + dist = self.to_probability_distribution( + states_active, + estimator_outputs, + **policy_kwargs, + ) + + # Save current estimator output only when requested. + if save_estimator_outputs: + ctx.current_estimator_output = estimator_outputs + + if step_mask is not None: + padded = torch.full( + (ctx.batch_size,) + estimator_outputs.shape[1:], + -float("inf"), + device=ctx.device, + ) + padded[step_mask] = estimator_outputs + ctx.trajectory_estimator_outputs.append(padded) + else: + ctx.current_estimator_output = None + + return dist, ctx + + class Estimator(ABC, nn.Module): r"""Base class for modules mapping states to distributions or scalar values. @@ -172,8 +455,8 @@ def __init__( Args: module: The neural network module to use. - preprocessor: Preprocessor object that transforms states to tensors. If None, - uses `IdentityPreprocessor` with the module's input_dim. + preprocessor: Preprocessor object that transforms states to tensors. If + None, uses `IdentityPreprocessor` with the module's input_dim. reduction: String name of one of the REDUCTION_FUNCTIONS keys. """ super().__init__(module, preprocessor, False) @@ -219,6 +502,12 @@ class LogitBasedEstimator(Estimator): This class is used to define estimators that output logits, which can be used to construct probability distributions. + + Attributes: + module: The neural network module to use. + preprocessor: Preprocessor object that transforms raw States objects to tensors. + is_backward: Flag indicating whether this estimator is for backward policy, + i.e., is used for predicting probability distributions over parents. """ @staticmethod @@ -365,9 +654,13 @@ def _compute_logits_for_distribution( class ConditionalLogZEstimator(ScalarEstimator): """Conditional logZ estimator. - This estimator is used to estimate the logZ of a GFlowNet from a conditioning tensor. - Since conditioning is a tensor, it does not have a preprocessor. Reduction is used - to aggregate the outputs of the module into a single scalar. + This estimator is used to estimate the logZ of a GFlowNet from a conditioning + tensor. Since conditioning is a tensor, it does not have a preprocessor. + Reduction is used to aggregate the outputs of the module into a single scalar. + + Attributes: + module: The neural network module to use. + reduction: String name of one of the REDUCTION_FUNCTIONS keys. """ def __init__(self, module: nn.Module, reduction: str = "mean"): @@ -377,7 +670,7 @@ def _calculate_module_output(self, input: torch.Tensor) -> torch.Tensor: return self.module(input) -class DiscretePolicyEstimator(LogitBasedEstimator): +class DiscretePolicyEstimator(PolicyMixin, LogitBasedEstimator): r"""Forward or backward policy estimators for discrete environments. Estimates either: @@ -658,7 +951,7 @@ def to_probability_distribution( raise NotImplementedError -class DiscreteGraphPolicyEstimator(LogitBasedEstimator): +class DiscreteGraphPolicyEstimator(PolicyMixin, LogitBasedEstimator): r"""Forward or backward policy estimators for graph-based environments. Estimates either, where $s$ and $s'$ are graph states: @@ -798,3 +1091,133 @@ def expected_output_dim(self) -> Optional[int]: None, as the output_dim of a TensorDict is not well-defined. """ return None + + +class RecurrentDiscretePolicyEstimator(RecurrentPolicyMixin, DiscretePolicyEstimator): + """Discrete policy estimator for recurrent architectures with explicit carry. + + Many sequence models (e.g., RNN/LSTM/GRU/Transformer in autoregressive mode) + maintain a recurrent hidden state ("carry") that must be threaded through + successive calls during sampling. This class formalizes that pattern for + GFlowNet policies by: + + - Exposing a forward signature ``forward(states, carry) -> (logits, carry)`` + so the policy can update and return the next carry at each step. + - Requiring an ``init_carry(batch_size, device)`` method to allocate the + initial hidden state for a rollout. + - Ensuring the per-step output (``logits`` over actions) is derived from the + latest token/time step while the internal model may process sequences. + + The sampler uses a ``RecurrentPolicyMixin`` which calls this estimator + with the current carry, updates the carry on every step, and records + per-step artifacts. Non-recurrent estimators should use the default PolicyMixin + and the standard ``DiscretePolicyEstimator`` base class instead. + + Notes + ----- + - Forward is intended for on-policy generation; off-policy evaluation over + entire trajectories typically requires different batching and masking. + - ``init_carry`` is a hard requirement for compatibility with the recurrent + PolicyMixin. + + Attributes: + module: The neural network module to use. + n_actions: Total number of actions in the discrete environment. + preprocessor: Preprocessor object that transforms states to tensors. + is_backward: Flag indicating whether this estimator is for backward policy, + i.e., is used for predicting probability distributions over parents. + """ + + def __init__( + self, + module: nn.Module, + n_actions: int, + preprocessor: Preprocessor | None = None, + is_backward: bool = False, + ): + """Initializes a RecurrentDiscretePolicyEstimator. + + Args: + module: The neural network module to use. + n_actions: Total number of actions in the discrete environment. + preprocessor: Preprocessor object that transforms states to tensors. + """ + if preprocessor is None: + preprocessor = IdentityPreprocessor(output_dim=None) + + super().__init__( + module=module, + n_actions=n_actions, + preprocessor=preprocessor, + is_backward=is_backward, + ) + + def forward( + self, + states: States, + carry: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Forward pass of the module. + + Args: + states: The input states. + carry: The carry from the previous step. + + Returns: + The output of the module, as a tensor of shape (*batch_shape, output_dim). + """ + # Prepare integer token sequences without -1 padding and use a BOS index. + # We infer the active sequence length per row from (token != -1). + tokens = states.tensor + if not torch.is_floating_point(tokens): + tokens = tokens.long() + else: + tokens = tokens.to(dtype=torch.long) + + # Replace padding (-1) with BOS index expected by the sequence model. + # RecurrentDiscreteSequenceModel reserves index == vocab_size for BOS. + bos_index = getattr(self.module, "vocab_size", self.n_actions - 1) + tokens = torch.where( + tokens < 0, torch.as_tensor(bos_index, device=tokens.device), tokens + ) + + # Determine a common prefix length across the (active) batch. + # Active rows in a rollout step share the same length; use max for safety. + # We still derive length from original states.tensor where -1 marks padding. + original = states.tensor + valid_mask = original >= 0 + if valid_mask.ndim == 1: + max_len = int(valid_mask.sum().item()) + else: + max_len = int(valid_mask.sum(dim=-1).max().item()) + if max_len <= 0: + max_len = 1 # Ensure at least BOS is processed + + # Trim to the common active prefix length and run the sequence model. + seq_input = tokens[..., :max_len] + logits, carry = self.module(seq_input, carry) + + # Use the logits corresponding to the last processed token. + logits = logits[:, -1, :] # (b, n_actions) + + if self.expected_output_dim is not None: + assert logits.shape[-1] == self.expected_output_dim, ( + f"Module output shape {logits.shape} does not match expected output " + f"dimension {self.expected_output_dim}" + ) + + return logits, carry + + def init_carry( + self, + batch_size: int, + device: torch.device, + ) -> dict[str, torch.Tensor]: + init_carry = getattr(self.module, "init_carry", None) + if not callable(init_carry): + raise NotImplementedError( + "Module does not implement init_carry(batch_size, device)." + ) + init_carry_fn = cast(Callable[[int, torch.device], Any], init_carry) + + return init_carry_fn(batch_size, device) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 0c5efe64..bd228268 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -171,7 +171,10 @@ class PFBasedGFlowNet(GFlowNet[TrainingSampleType], ABC): """ def __init__( - self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False + self, + pf: Estimator, + pb: Estimator | None, + constant_pb: bool = False, ) -> None: """Initializes a PFBasedGFlowNet instance. @@ -183,6 +186,7 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + """ super().__init__() # Technical note: pb may be constant for a variety of edge cases, for example, @@ -212,6 +216,27 @@ def __init__( self.pb = pb self.constant_pb = constant_pb + # Advisory: recurrent PF with non-recurrent PB is unusual + # (tree DAGs typically prefer pb=None with constant_pb=True). + # Import locally to avoid circular imports during module import time. + from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore + + if isinstance(self.pf, RecurrentDiscretePolicyEstimator) and isinstance( + self.pb, Estimator + ): + warnings.warn( + "Using a recurrent PF, which is only valid for tree DAGs, with a " + "non-recurrent PB is unusual. " + "Consider using pb=None with constant_pb=True for tree DAGs.", + ) + # Disallow recurrent PB estimators universally. + # I'm not actually sure we should disallow this. + if isinstance(self.pb, RecurrentDiscretePolicyEstimator): + raise TypeError( + "Recurrent PB estimators are not supported. Use a non-recurrent PB " + "or set pb=None with constant_pb=True for tree DAGs." + ) + def sample_trajectories( self, env: Env, @@ -275,20 +300,28 @@ class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]): """ def __init__( - self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False + self, + pf: Estimator, + pb: Estimator | None, + constant_pb: bool = False, ) -> None: """Initializes a TrajectoryBasedGFlowNet instance. Args: pf: The forward policy estimator. - pb: The backward policy estimator, or None if the gflownet DAG is a tree, and - pb is therefore always 1. + pb: The backward policy estimator, or None if the gflownet DAG is a tree, + and pb is therefore always 1. constant_pb: Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + """ - super().__init__(pf, pb, constant_pb=constant_pb) + super().__init__( + pf, + pb, + constant_pb=constant_pb, + ) def get_pfs_and_pbs( self, @@ -301,8 +334,8 @@ def get_pfs_and_pbs( More specifically, it evaluates $\log P_F(s' \mid s)$ and $\log P_B(s \mid s')$ for each transition in each trajectory in the batch. - If recalculate_all_logprobs=True, we re-evaluate the logprobs of the trajectories - using the current self.pf. Otherwise, the following applies: + If recalculate_all_logprobs=True, we re-evaluate the logprobs of the + trajectories using the current self.pf. Otherwise, the following applies: - If trajectories have logprobs attribute, use them - this is usually for on-policy learning. - Elif trajectories have estimator_outputs attribute, transform them into @@ -322,7 +355,11 @@ def get_pfs_and_pbs( the log_pf and log_pb for each action in each trajectory. """ return get_trajectory_pfs_and_pbs( - self.pf, self.pb, trajectories, fill_value, recalculate_all_logprobs + self.pf, + self.pb, + trajectories, + fill_value, + recalculate_all_logprobs, ) def get_scores( diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index efb53d1b..8879c79a 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -93,12 +93,23 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + """ super().__init__(pf, pb, constant_pb=constant_pb) + + # Disallow recurrent PF for transition-based DB + from gfn.estimators import RecurrentDiscretePolicyEstimator # type: ignore + + if isinstance(self.pf, RecurrentDiscretePolicyEstimator): + raise TypeError( + "DBGFlowNet does not support recurrent PF estimators (transitions path cannot propagate carry)." + ) + assert any( isinstance(logF, cls) for cls in [ScalarEstimator, ConditionalScalarEstimator] ), "logF must be a ScalarEstimator or derived" + self.logF = logF self.forward_looking = forward_looking self.log_reward_clip_min = log_reward_clip_min @@ -148,7 +159,10 @@ def get_pfs_and_pbs( log_pb for each transition. """ return get_transition_pfs_and_pbs( - self.pf, self.pb, transitions, recalculate_all_logprobs + self.pf, + self.pb, + transitions, + recalculate_all_logprobs, ) def get_scores( @@ -301,9 +315,19 @@ class ModifiedDBGFlowNet(PFBasedGFlowNet[Transitions]): """ def __init__( - self, pf: Estimator, pb: Estimator | None, constant_pb: bool = False + self, + pf: Estimator, + pb: Estimator | None, + constant_pb: bool = False, ) -> None: - """Initializes a ModifiedDBGFlowNet instance.""" + """Initializes a ModifiedDBGFlowNet instance. + + Args: + pf: Forward policy estimator. + pb: Backward policy estimator or None. + constant_pb: See base class. + + """ super().__init__(pf, pb, constant_pb=constant_pb) def get_scores( diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 79cc9a52..d4c8bb5e 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -5,7 +5,10 @@ from gfn.containers import StatesContainer, Trajectories from gfn.env import DiscreteEnv -from gfn.estimators import ConditionalDiscretePolicyEstimator, DiscretePolicyEstimator +from gfn.estimators import ( + DiscretePolicyEstimator, + PolicyMixin, +) from gfn.gflownet.base import GFlowNet, loss_reduce from gfn.samplers import Sampler from gfn.states import DiscreteStates @@ -32,6 +35,10 @@ class FMGFlowNet(GFlowNet[StatesContainer[DiscreteStates]]): logF: A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating the log flow of the edges (states -> next_states). alpha: A scalar weight for the reward matching loss. + + Flow Matching does not rely on PF/PB probability recomputation. Any trajectory + sampling provided by this class is for diagnostics/visualization and can only use + the default (non-recurrent) PolicyMixin interface. """ def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): @@ -43,11 +50,10 @@ def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): alpha: A scalar weight for the reward matching loss. """ super().__init__() - assert isinstance( - logF, - DiscretePolicyEstimator | ConditionalDiscretePolicyEstimator, - ), "logF must be a DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator" + logF, PolicyMixin + ), "logF must use the default PolicyMixin interface" + self.logF = logF self.alpha = alpha diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index 34175654..bcb1f919 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -100,6 +100,7 @@ def __init__( gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. + """ super().__init__(pf, pb, constant_pb=constant_pb) assert any( @@ -157,14 +158,14 @@ def cumulative_logprobs( def calculate_preds( self, - log_pf_trajectories_cum: CumulativeLogProbsTensor, + log_pf_traj_cum: CumulativeLogProbsTensor, log_state_flows: LogStateFlowsTensor, i: int, ) -> PredictionsTensor: """Calculates the predictions tensor for the current sub-trajectory length. Args: - log_pf_trajectories_cum: Tensor of shape (max_length + 1, n_trajectories) + log_pf_traj_cum: Tensor of shape (max_length + 1, n_trajectories) containing the cumulative sum of logprobs of the forward actions for each trajectory. log_state_flows: Tensor of shape (max_length, n_trajectories) containing @@ -178,11 +179,7 @@ def calculate_preds( log_state_flows if i == 1 else log_state_flows[: -(i - 1)] ) - preds = ( - log_pf_trajectories_cum[i:] - - log_pf_trajectories_cum[:-i] - + current_log_state_flows - ) + preds = log_pf_traj_cum[i:] - log_pf_traj_cum[:-i] + current_log_state_flows return preds @@ -190,7 +187,7 @@ def calculate_targets( self, trajectories: Trajectories, preds: PredictionsTensor, - log_pb_trajectories_cum: CumulativeLogProbsTensor, + log_pb_traj_cum: CumulativeLogProbsTensor, log_state_flows: LogStateFlowsTensor, is_terminal_mask: MaskTensor, sink_states_mask: MaskTensor, @@ -202,7 +199,7 @@ def calculate_targets( trajectories: The batch of trajectories. preds: Tensor of shape (max_length + 1 - i, n_trajectories) containing the predictions for the current sub-trajectory length. - log_pb_trajectories_cum: Tensor of shape (max_length + 1, n_trajectories) + log_pb_traj_cum: Tensor of shape (max_length + 1, n_trajectories) containing the cumulative sum of logprobs of the backward actions for each trajectory. log_state_flows: Tensor of shape (max_length, n_trajectories) containing @@ -229,15 +226,16 @@ def calculate_targets( # We need to add to that the log-probabilities of the backward actions up-to # the sub-trajectory's terminating state if i > 1: - targets[is_terminal_mask[i - 1 :]] += ( - log_pb_trajectories_cum[i - 1 :] - log_pb_trajectories_cum[: -i + 1] - )[:-1][is_terminal_mask[i - 1 :]] + delta_pb = (log_pb_traj_cum[i - 1 :] - log_pb_traj_cum[: -i + 1])[:-1] + targets[is_terminal_mask[i - 1 :]] += delta_pb[is_terminal_mask[i - 1 :]] # The following creates the targets for the non-finishing sub-trajectories full_mask = sink_states_mask | is_terminal_mask + delta_pb2 = (log_pb_traj_cum[i:] - log_pb_traj_cum[:-i])[:-1] + rhs_mask = ~full_mask[i - 1 : -1] targets[~full_mask[i - 1 :]] = ( - log_pb_trajectories_cum[i:] - log_pb_trajectories_cum[:-i] - )[:-1][~full_mask[i - 1 : -1]] + log_state_flows[i:][~sink_states_mask[i:]] + delta_pb2[rhs_mask] + log_state_flows[i:][~sink_states_mask[i:]] + ) return targets @@ -445,11 +443,8 @@ def get_tb_contributions(self, trajectories: Trajectories) -> ContributionsTenso contributions = torch.zeros(n_rows, len(trajectories)) # Each trajectory contributes one element to the loss, equally weighted - terminating_idx = trajectories.terminating_idx - indices = ( - max_len * (terminating_idx - 1) - - (terminating_idx - 1) * (terminating_idx - 2) / 2 - ).long() + t_idx = trajectories.terminating_idx + indices = (max_len * (t_idx - 1) - (t_idx - 1) * (t_idx - 2) / 2).long() contributions.scatter_(0, indices.unsqueeze(0), 1) contributions = contributions / len(trajectories) @@ -498,16 +493,16 @@ def get_geometric_within_contributions( """ L = self.lamda max_len = trajectories.max_length - terminating_idx = trajectories.terminating_idx + t_idx = trajectories.terminating_idx # The following tensor represents the weights given to each possible # sub-trajectory length. - contributions = ( - L ** torch.arange(max_len, device=terminating_idx.device).double() - ).to(torch.get_default_dtype()) + contributions = (L ** torch.arange(max_len, device=t_idx.device).double()).to( + torch.get_default_dtype() + ) contributions = contributions.unsqueeze(-1).repeat(1, len(trajectories)) contributions = contributions.repeat_interleave( - torch.arange(max_len, 0, -1, device=terminating_idx.device), + torch.arange(max_len, 0, -1, device=t_idx.device), dim=0, output_size=int(max_len * (max_len + 1) / 2), ) @@ -519,10 +514,7 @@ def get_geometric_within_contributions( per_trajectory_denom = ( 1.0 / (1 - L) ** 2 - * ( - L * (L ** terminating_idx.double() - 1) - + (1 - L) * terminating_idx.double() - ) + * (L * (L ** t_idx.double() - 1) + (1 - L) * t_idx.double()) ).to(torch.get_default_dtype()) contributions = contributions / per_trajectory_denom / len(trajectories) diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index 5b50c7ca..57933bbf 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -8,7 +8,7 @@ from torch import Size, Tensor from torch.distributions import Beta, Categorical, Distribution, MixtureSameFamily -from gfn.estimators import Estimator +from gfn.estimators import Estimator, PolicyMixin from gfn.gym import Box from gfn.states import States from gfn.utils.modules import MLP @@ -936,7 +936,7 @@ def split_PF_module_output( return (exit_probability, mixture_logits, alpha_theta, beta_theta, alpha_r, beta_r) -class BoxPFEstimator(Estimator): +class BoxPFEstimator(Estimator, PolicyMixin): r"""Estimator for `P_F` for the Box environment. This estimator uses the `DistributionWrapper` distribution. @@ -1060,7 +1060,7 @@ def _normalize(x: Tensor) -> Tensor: ) -class BoxPBEstimator(Estimator): +class BoxPBEstimator(Estimator, PolicyMixin): r"""Estimator for `P_B` for the Box environment. This estimator uses the `QuarterCircle(northeastern=False)` distribution. diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py index 20d63359..b0fbadd2 100644 --- a/src/gfn/preprocessors.py +++ b/src/gfn/preprocessors.py @@ -23,15 +23,21 @@ class Preprocessor(ABC): dimension will not be checked. """ - def __init__(self, output_dim: int | None) -> None: + def __init__( + self, output_dim: int | None, target_dtype: torch.dtype | None = None + ) -> None: """Initializes a Preprocessor with the specified output dimension. Args: output_dim: The dimensionality of the preprocessed output tensor, which is compatible with the neural network that will be used. If None, the output dimension will not be checked. + target_dtype: Optional dtype to cast tensor outputs to. When set, any + tensor returned by `preprocess` will be cast to this dtype in + `__call__` before returning. """ self.output_dim = output_dim + self.target_dtype = target_dtype @abstractmethod def preprocess(self, states: States) -> torch.Tensor: @@ -55,8 +61,11 @@ def __call__(self, states: States | GraphStates) -> torch.Tensor | GeometricBatc The preprocessed states as a tensor or GeometricBatch. """ out = self.preprocess(states) - if isinstance(out, torch.Tensor) and self.output_dim is not None: - assert out.shape[-1] == self.output_dim + if isinstance(out, torch.Tensor): + if self.output_dim is not None: + assert out.shape[-1] == self.output_dim + if self.target_dtype is not None and out.dtype != self.target_dtype: + out = out.to(self.target_dtype) return out diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index bb6d48b6..686ef924 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -1,41 +1,33 @@ -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, cast import torch from gfn.actions import Actions from gfn.containers import Trajectories from gfn.env import Env -from gfn.estimators import Estimator +from gfn.estimators import Estimator, PolicyEstimatorProtocol from gfn.states import GraphStates, States from gfn.utils.common import ensure_same_device from gfn.utils.graphs import graph_states_share_storage -from gfn.utils.handlers import ( - has_conditioning_exception_handler, - no_conditioning_exception_handler, -) from gfn.utils.prob_calculations import get_trajectory_pbs, get_trajectory_pfs class Sampler: - """Wrapper for a PolicyEstimator that enables sampling from GFlowNet environments. + """Estimator‑driven sampler for GFlowNet environments. - A Sampler encapsulates a PolicyEstimator and provides methods to sample individual - actions or complete trajectories from GFlowNet environments. It can be used for - both forward and backward sampling, depending on the estimator's configuration. + The estimator builds action distributions, computes step log‑probs, and records + artifacts into a rollout context via method flags. Direction (forward/backward) + is determined by ``estimator.is_backward``. Attributes: - estimator: The PolicyEstimator used for sampling actions and computing - probability distributions. + estimator: The underlying policy estimator. Must expose the methods contained + in the `PolicyMixin` mixin. """ def __init__(self, estimator: Estimator) -> None: - """Initializes a Sampler with a PolicyEstimator. - - Args: - estimator: The PolicyEstimator to use for sampling actions and computing - probability distributions. - """ + """Initializes a Sampler with a PolicyEstimator.""" self.estimator = estimator + # TODO: Assert that the estimator exposes the methods contained in the `PolicyMixin` mixin. def sample_actions( self, @@ -44,73 +36,95 @@ def sample_actions( conditioning: torch.Tensor | None = None, save_estimator_outputs: bool = False, save_logprobs: bool = False, + ctx: Any | None = None, **policy_kwargs: Any, ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]: - """Samples actions from the given states using the policy estimator. - - This method samples actions from the probability distribution defined by the - policy estimator. + """Sample one step from ``states`` via the estimator. - When sampling off-policy, ensure to set `save_logprobs=False`. Log probabilities - for off-policy actions should be calculated separately during GFlowNet training. + Initializes or reuses a rollout context with ``estimator.init_context``, + builds a Distribution with ``estimator.compute_dist``, and optionally computes + log‑probs with ``estimator.log_probs``. Per‑step artifacts are recorded by + the estimator when the corresponding flags are set. Args: - env: The environment where the states and actions are defined. - states: A batch of states to sample actions from. - conditioning: Optional tensor of conditioning information for conditional - policies. If provided, the estimator must support conditional sampling. - save_estimator_outputs: If True, returns the raw outputs from the estimator - before conversion to probability distributions. This is useful for - off-policy training with tempered policies. - save_logprobs: If True, calculates and returns the log probabilities of - the sampled actions under the policy distribution. This is useful for - on-policy training. - **policy_kwargs: Keyword arguments passed to the estimator's - `to_probability_distribution` method. Common parameters include: - - `temperature`: Scalar to divide logits by before softmax - - `epsilon`: Probability of choosing random actions (exploration) - - `sf_bias`: Bias to apply to exit action logits + env: Environment providing action/state conversion utilities. + states: Batch of states to act on. + conditioning: Optional conditioning for conditional policies. + save_estimator_outputs: If True, return the raw estimator outputs + cached by the PolicyMixin for this step. Useful for off-policy training + with tempered policies. + save_logprobs: If True, return per‑step log‑probs padded to batch. + Useful for on-policy training. + **policy_kwargs: Extra kwargs forwarded to + ``to_probability_distribution``. Returns: - A tuple containing: - - An Actions object with the sampled actions - - Optional tensor of log probabilities (if save_logprobs=True) - - Optional tensor of estimator outputs (if save_estimator_outputs=True) + ``(Actions, log_probs | None, estimator_outputs | None)``. The + estimator outputs come from + ``PolicyMixin.get_current_estimator_output(ctx)`` when requested. """ - # TODO: Should estimators instead ignore None for the conditioning vector? - if conditioning is not None: - with has_conditioning_exception_handler("estimator", self.estimator): - estimator_output = self.estimator(states, conditioning) - else: - with no_conditioning_exception_handler("estimator", self.estimator): - estimator_output = self.estimator(states) + # NOTE: Explicitly cast to the policy protocol so static analyzers know + # the estimator exposes the mixin methods (init_context/compute_dist/log_probs). + policy_estimator = cast(PolicyEstimatorProtocol, self.estimator) + # Runtime guard: ensure the estimator actually implements the required protocol methods. + # This keeps helpful error messages when a non‑policy estimator is supplied. + for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_estimator, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + + if ctx is None: + ctx = policy_estimator.init_context( + batch_size=states.batch_shape[0], + device=states.device, + conditioning=conditioning, + ) - dist = self.estimator.to_probability_distribution( - states, estimator_output, **policy_kwargs + step_mask = torch.ones( + states.batch_shape[0], dtype=torch.bool, device=states.device + ) + dist, ctx = policy_estimator.compute_dist( + states, + ctx, + step_mask, + save_estimator_outputs=save_estimator_outputs, + **policy_kwargs, ) with torch.no_grad(): - actions = dist.sample() + actions_tensor = dist.sample() if save_logprobs: - log_probs = dist.log_prob(actions) - if torch.any(torch.isinf(log_probs)): - raise RuntimeError("Log probabilities are inf. This should not happen.") + # Use estimator to compute step log-probs and pad to batch. + log_probs, ctx = policy_estimator.log_probs( + actions_tensor, + dist, + ctx, + step_mask, + vectorized=False, + save_logprobs=True, + ) else: log_probs = None - actions = env.actions_from_tensor(actions) + actions = env.actions_from_tensor(actions_tensor) - if not save_estimator_outputs: - estimator_output = None + estimator_output = None + if save_estimator_outputs: + if not hasattr(policy_estimator, "get_current_estimator_output"): + raise TypeError( + "Estimator does not support get_current_estimator_output and save_estimator_outputs is True!" + ) + estimator_output = policy_estimator.get_current_estimator_output(ctx) + assert estimator_output is not None assert log_probs is None or log_probs.shape == actions.batch_shape - # assert estimator_output is None or estimator_output.shape == actions.batch_shape - # TODO: check expected shape return actions, log_probs, estimator_output - def sample_trajectories( + # TODO: How to avoid "Sampler.sample_trajectories' is too complex" error? + def sample_trajectories( # noqa: C901 self, env: Env, n: Optional[int] = None, @@ -120,40 +134,41 @@ def sample_trajectories( save_logprobs: bool = False, **policy_kwargs: Any, ) -> Trajectories: - """Samples complete trajectories from the environment. + """Roll out complete trajectories using the estimator. - This method samples trajectories by sequentially sampling actions from the - policy estimator. It supports both forward and backward sampling, depending on - the estimator's `is_backward` flag. If forward sampling, it samples until all - trajectories reach the sink state. If backward sampling, it samples until all - trajectories reach the initial state. + Reuses a single rollout context across steps, calling + ``compute_dist`` & ``log_probs`` each iteration. Uses + ``estimator.is_backward`` to choose the environment step function. Args: - env: The environment to sample trajectories from. - n: Number of trajectories to sample, all starting from s0. Must be - provided if `states` is None. - states: Initial states to start trajectories from. It should have batch_shape - of length 1 (no trajectory dim). If `None`, `n` must be provided and we - initialize `n` trajectories with the environment's initial state. - conditioning: Optional tensor of conditioning information for conditional - policies. Must match the batch shape of states. - save_estimator_outputs: If True, saves the estimator outputs for each - step. Useful for off-policy training with tempered policies. - save_logprobs: If True, calculates and saves the log probabilities of - sampled actions. Useful for on-policy training. - **policy_kwargs: Keyword arguments passed to the policy estimator. - See `sample_actions` for details. + env: Environment to sample in. + n: Number of trajectories if ``states`` is None. + states: Starting states (batch shape length 1) or ``None``. + conditioning: Optional conditioning aligned with the batch. + save_estimator_outputs: If True, store per‑step estimator outputs. Useful + for off-policy training with tempered policies. + save_logprobs: If True, store per‑step log‑probs. Useful for on-policy + training. + **policy_kwargs: Extra kwargs forwarded to the policy. Returns: - A Trajectories object containing the sampled trajectories with batch_shape - (max_length+1, n_trajectories) for states and (max_length, n_trajectories) - for actions. + A ``Trajectories`` with stacked states/actions and any artifacts. Note: For backward trajectories, the reward is computed at the initial state (s0) rather than the terminal state (sf). """ - if self.estimator.is_backward: + # NOTE: Cast to the policy protocol for static typing across mixin methods/properties. + policy_estimator = cast(PolicyEstimatorProtocol, self.estimator) + # Runtime guard: ensure the estimator actually implements the required protocol + # method and raises an error when a non‑policy estimator is supplied. + for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_estimator, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + + if policy_estimator.is_backward: # [ASSUMPTION] When backward sampling, all provided states are the # terminating states (can be passed to log_reward fn) assert ( @@ -177,72 +192,63 @@ def sample_trajectories( assert states.batch_shape == conditioning.shape[: len(states.batch_shape)] ensure_same_device(states.device, conditioning.device) - dones = ( - states.is_initial_state - if self.estimator.is_backward - else states.is_sink_state - ) + if policy_estimator.is_backward: + dones = states.is_initial_state + else: + dones = states.is_sink_state # Define dummy actions to avoid errors when stacking empty lists. trajectories_states: List[States] = [states] trajectories_actions: List[Actions] = [ env.actions_from_batch_shape((n_trajectories,)) ] - trajectories_logprobs: List[torch.Tensor] = [ - torch.full((n_trajectories,), fill_value=0, device=device) - ] + # Placeholder kept for backward-compatibility of shapes; logprobs are + # recorded and stacked by the estimator via the context. trajectories_terminating_idx = torch.zeros( n_trajectories, dtype=torch.long, device=device ) step = 0 - all_estimator_outputs = [] + if not hasattr(policy_estimator, "init_context"): + raise TypeError("Estimator is not policy-capable (missing PolicyMixin)") + ctx = policy_estimator.init_context(n_trajectories, device, conditioning) while not all(dones): actions = env.actions_from_batch_shape((n_trajectories,)) - log_probs = torch.full((n_trajectories,), fill_value=0.0, device=device) - # This optionally allows you to retrieve the estimator_outputs collected - # during sampling. This is useful if, for example, you want to evaluate off - # policy actions later without repeating calculations to obtain the env - # distribution parameters. - if conditioning is not None: - masked_conditioning = conditioning[~dones] - else: - masked_conditioning = None - - valid_actions, actions_log_probs, estimator_outputs = self.sample_actions( - env, - states[~dones], - masked_conditioning, - save_estimator_outputs=True if save_estimator_outputs else False, - save_logprobs=save_logprobs, + step_mask = ~dones + + # Compute distribution on active rows + dist, ctx = policy_estimator.compute_dist( + states[step_mask], + ctx, + step_mask, + save_estimator_outputs=save_estimator_outputs, **policy_kwargs, ) - if estimator_outputs is not None: - # Place estimator outputs into a stackable tensor. Note that this - # will be replaced with torch.nested.nested_tensor in the future. - estimator_outputs_padded = torch.full( - (n_trajectories,) + estimator_outputs.shape[1:], - fill_value=-float("inf"), - device=device, - ) - estimator_outputs_padded[~dones] = estimator_outputs - all_estimator_outputs.append(estimator_outputs_padded) - actions[~dones] = valid_actions + # Sample actions for active rows + with torch.no_grad(): + valid_actions_tensor = dist.sample() + valid_actions = env.actions_from_tensor(valid_actions_tensor) + if save_logprobs: - assert ( - actions_log_probs is not None - ), "actions_log_probs should not be None when save_logprobs is True" - log_probs[~dones] = actions_log_probs + # Use estimator to compute step log-probs and pad to batch (recorded in ctx). + _, ctx = policy_estimator.log_probs( + valid_actions_tensor, + dist, + ctx, + step_mask, + vectorized=False, + save_logprobs=True, + ) + actions[step_mask] = valid_actions trajectories_actions.append(actions) - trajectories_logprobs.append(log_probs) - if self.estimator.is_backward: - new_states = env._backward_step(states, actions) + if policy_estimator.is_backward: + new_states = env._backward_step(states, actions) # type: ignore[attr-defined] else: - new_states = env._step(states, actions) + new_states = env._step(states, actions) # type: ignore[attr-defined] # Ensure that the new state is a distinct object from the old state. assert new_states is not states @@ -265,7 +271,7 @@ def sample_trajectories( # to filter out the already done ones. new_dones = ( new_states.is_initial_state - if self.estimator.is_backward + if policy_estimator.is_backward else new_states.is_sink_state ) & ~dones trajectories_terminating_idx[new_dones] = step @@ -274,36 +280,40 @@ def sample_trajectories( dones = dones | new_dones trajectories_states.append(states) - # Stack all states and actions + # Stack all states and actions. stacked_states = env.States.stack(trajectories_states) - stacked_actions = env.Actions.stack(trajectories_actions)[ - 1: - ] # Drop dummy action + + # Stack actions, drop dummy action. + stacked_actions = env.Actions.stack(trajectories_actions)[1:] + + # Get trajectory artifacts from the context (already shaped (T, N, ...)) stacked_logprobs = ( - torch.stack(trajectories_logprobs, dim=0)[1:] # Drop dummy logprob - if save_logprobs + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs else None ) - - # TODO: use torch.nested.nested_tensor(dtype, device, requires_grad). stacked_estimator_outputs = ( - torch.stack(all_estimator_outputs, dim=0) if save_estimator_outputs else None + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None ) - # If there are no logprobs or estimator outputs, set them to None. - # TODO: This is a hack to avoid errors when no logprobs or estimator outputs are - # saved. This bug was introduced when I changed the dtypes library-wide -- why - # is this happening? - if stacked_logprobs is not None and len(stacked_logprobs) == 0: - stacked_logprobs = None - if stacked_estimator_outputs is not None and len(stacked_estimator_outputs) == 0: - stacked_estimator_outputs = None + # Stacked logprobs and estimator outputs are only None if there are no + # valid trajectories. + if stacked_logprobs is not None: + if len(stacked_logprobs) == 0: + stacked_logprobs = None + + if stacked_estimator_outputs is not None: + if len(stacked_estimator_outputs) == 0: + stacked_estimator_outputs = None # Broadcast conditioning tensor to match states batch shape if needed if conditioning is not None: - # The states have batch shape (max_length, n_trajectories) - # The conditioning tensor should have shape (n_trajectories,) or (n_trajectories, 1) - # We need to broadcast it to (max_length, n_trajectories, 1) for the estimator + # The states have batch shape (max_length, n_trajectories). The + # conditioning tensor should have shape (n_trajectories,) or + # (n_trajectories, 1). We need to broadcast it to (max_length, + # n_trajectories, 1) for the estimator if len(conditioning.shape) == 1: # conditioning has shape (n_trajectories,) conditioning = ( @@ -323,7 +333,7 @@ def sample_trajectories( conditioning=conditioning, actions=stacked_actions, terminating_idx=trajectories_terminating_idx, - is_backward=self.estimator.is_backward, + is_backward=policy_estimator.is_backward, log_rewards=None, # will be calculated later log_probs=stacked_logprobs, estimator_outputs=stacked_estimator_outputs, @@ -354,7 +364,7 @@ def __init__( self, pf_estimator: Estimator, pb_estimator: Estimator, - ): + ) -> None: """Initializes a LocalSearchSampler with forward and backward estimators. Args: diff --git a/src/gfn/states.py b/src/gfn/states.py index a8bcc178..8abb4a3a 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -233,7 +233,14 @@ def __setitem__( index: Indices to set. states: States object containing the new states. """ - self.tensor[index] = states.tensor + # Align dtype/device of the source to the destination slice to avoid + # runtime errors from mismatched tensor properties during indexed writes. + # Note: we intentionally do not mutate `states.tensor` in-place. + dest = self.tensor + src = states.tensor + if src.dtype != dest.dtype or src.device != dest.device: + src = src.to(device=dest.device, dtype=dest.dtype) + self.tensor[index] = src def clone(self) -> States: """Returns a clone of the current instance. @@ -522,7 +529,9 @@ def __getitem__( return out def __setitem__( - self, index: int | Sequence[int] | Sequence[bool], states: DiscreteStates + self, + index: int | Sequence[int] | Sequence[bool] | torch.Tensor, + states: DiscreteStates, ) -> None: """Sets particular discrete states and their masks. @@ -672,6 +681,58 @@ def init_forward_masks(self, set_ones: bool = True) -> None: self.forward_masks = torch.zeros(shape).to(self.device).bool() +class ChunkedStates(DiscreteStates): + """Reusable ChunkedStates base used by chunking-aware environments. + + Env factories should return a subclass that binds env-specific class variables + (state_shape, s0, sf, n_actions, device) and two hooks: + - get_n_actions: Callable[[], int] returning current env.n_actions + - overlay_masks: Callable[[ChunkedStates], None] applying env overlays + """ + + # Hooks to be provided by the environment-specific subclass + get_n_actions: ClassVar[Callable[[], int]] = lambda: cast( + int, ChunkedStates.n_actions + ) + overlay_masks: ClassVar[Optional[Callable[["ChunkedStates"], None]]] = None + + def __init__(self, tensor, forward_masks=None, backward_masks=None): + super().__init__(tensor, forward_masks, backward_masks) + self._ensure_current() + + def _ensure_current(self) -> None: + # Keep class-level n_actions in sync with env via hook if available + try: + self.__class__.n_actions = int(self.__class__.get_n_actions()) + except Exception: + pass + + if (self.forward_masks.shape[-1] != self.__class__.n_actions) or ( + self.backward_masks.shape[-1] != self.__class__.n_actions - 1 + ): + self.forward_masks = torch.ones( + (*self.batch_shape, self.__class__.n_actions), + dtype=torch.bool, + device=self.device, + ) + self.backward_masks = torch.ones( + (*self.batch_shape, self.__class__.n_actions - 1), + dtype=torch.bool, + device=self.device, + ) + + if self.__class__.overlay_masks is not None: + self.__class__.overlay_masks(self) + + def pad_dim0_with_sf(self, required_first_dim: int) -> None: + super().pad_dim0_with_sf(required_first_dim) + self._ensure_current() + + def extend(self, other: "ChunkedStates") -> None: + super().extend(other) + self._ensure_current() + + class GraphStates(States): """Base class for graph-based state representations. diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index fddebb69..5ca9cc14 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -1,12 +1,15 @@ """This file contains some examples of modules that can be used with GFN.""" import math +from abc import ABC, abstractmethod from typing import Literal, Optional import torch import torch.nn as nn +import torch.nn.functional as F from linear_attention_transformer import LinearAttentionTransformer from tensordict import TensorDict +from torch import Tensor from torch_geometric.nn import DirGNNConv, GCNConv, GINConv from gfn.actions import GraphActions, GraphActionType @@ -962,3 +965,539 @@ def bias(self) -> torch.Tensor | None: return self.bias_mu else: return None + + +class AutoregressiveDiscreteSequenceModel(ABC, nn.Module): + + @abstractmethod + def init_carry( + self, + batch_size: int, + device: torch.device, + ) -> dict[str, torch.Tensor]: + """Initialize the carry for the sequence model. + + Args: + batch_size (int): Batch size. + device (torch.device): Device to allocate carry tensors on. + + Returns: + dict[str, torch.Tensor]: Initialized carry. + """ + + @abstractmethod + def forward( + self, + x: torch.Tensor, + carry: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + """Compute the logits for the next tokens in the sequence. + + Args: + x (torch.Tensor): (B, T) tensor of input token indices where ``T`` is the + number of newly supplied timesteps (``T`` may be 1 for incremental + decoding). + carry (dict[str, torch.Tensor]): Carry from previous steps for recurrent + processing (e.g., hidden states). + + Returns: + tuple[torch.Tensor, dict[str, torch.Tensor]]: Logits for the next token + at each supplied timestep with shape (B, T, vocab) and updated carry. + """ + + @property + @abstractmethod + def vocab_size(self) -> int: + """Size of the vocabulary (excluding BOS token).""" + + +class RecurrentDiscreteSequenceModel(AutoregressiveDiscreteSequenceModel): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + hidden_size: int, + num_layers: int = 1, + rnn_type: Literal["lstm", "gru"] = "lstm", + dropout: float = 0.0, + ) -> None: + super().__init__() + if num_layers <= 0: + raise ValueError("num_layers must be a positive integer.") + rnn_kind = rnn_type.lower() + if rnn_kind not in {"lstm", "gru"}: + raise ValueError("rnn_type must be 'lstm' or 'gru'.") + + if not 0.0 <= dropout <= 1.0: + raise ValueError("dropout must be in the range [0, 1].") + + self._vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.hidden_size = hidden_size + self.num_layers = num_layers + self.rnn_type = rnn_kind + + self.embedding = nn.Embedding(vocab_size + 1, embedding_dim) # +1 for BOS token + rnn_dropout = dropout if num_layers > 1 else 0.0 + self.lstm: nn.LSTM | None + self.gru: nn.GRU | None + if rnn_kind == "lstm": + self.lstm = nn.LSTM( + input_size=embedding_dim, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=rnn_dropout, + ) + self.gru = None + else: + self.gru = nn.GRU( + input_size=embedding_dim, + hidden_size=hidden_size, + num_layers=num_layers, + batch_first=True, + dropout=rnn_dropout, + ) + self.lstm = None + self.output_projection = nn.Linear(hidden_size, vocab_size) + + def init_carry( + self, + batch_size: int, + device: torch.device, + ) -> dict[str, torch.Tensor]: + carry: dict[str, torch.Tensor] = { + "hidden": torch.zeros( + self.num_layers, batch_size, self.hidden_size, device=device + ), + } + if self.rnn_type == "lstm": + carry["cell"] = torch.zeros( + self.num_layers, batch_size, self.hidden_size, device=device + ) + return carry + + def forward( + self, + x: torch.Tensor, + carry: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + if x.dim() != 2: + raise ValueError("Expected input tensor with shape (batch, timesteps).") + + batch, timesteps = x.size() + device = x.device + + if "hidden" not in carry: + raise KeyError("Carry must provide a 'hidden' state tensor.") + + hidden = carry["hidden"] + if hidden.size(1) != batch: + raise ValueError( + "Hidden state batch dimension does not match the provided tokens." + ) + if hidden.device != device: + raise ValueError( + "Hidden state tensor must live on the same device as input tokens." + ) + + embedded = self.embedding(x) + + if self.rnn_type == "lstm": + lstm = self.lstm + if lstm is None: + raise RuntimeError("LSTM module was not initialized.") + if "cell" not in carry: + raise KeyError("LSTM carry must provide a 'cell' state tensor.") + cell = carry["cell"] + if cell.size(1) != batch: + raise ValueError( + "Cell state batch dimension does not match the provided tokens." + ) + if cell.device != device: + raise ValueError( + "Cell state tensor must live on the same device as input tokens." + ) + outputs, (hidden_next, cell_next) = lstm(embedded, (hidden, cell)) + updated_carry: dict[str, torch.Tensor] = { + "hidden": hidden_next, + "cell": cell_next, + } + else: + gru = self.gru + if gru is None: + raise RuntimeError("GRU module was not initialized.") + outputs, hidden_next = gru(embedded, hidden) + updated_carry = { + "hidden": hidden_next, + } + + logits = self.output_projection(outputs) + return logits, updated_carry + + @property + def vocab_size(self) -> int: + return self._vocab_size + + +class _AutoregressiveTransformerBlock(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + ff_hidden_dim: int, + dropout: float, + ) -> None: + super().__init__() + if embed_dim % num_heads != 0: + raise ValueError("Embedding dimension must be divisible by number of heads.") + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + + self.linear1 = nn.Linear(embed_dim, ff_hidden_dim) + self.linear2 = nn.Linear(ff_hidden_dim, embed_dim) + + self.attn_dropout = nn.Dropout(dropout) + self.residual_dropout = nn.Dropout(dropout) + self.ff_dropout = nn.Dropout(dropout) + + def forward( + self, + hidden: torch.Tensor, + key_carry: torch.Tensor, + value_carry: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch, timesteps, _ = hidden.size() + + normed_hidden = self.norm1(hidden) + + q = self.q_proj(normed_hidden) + k = self.k_proj(normed_hidden) + v = self.v_proj(normed_hidden) + + q = q.view(batch, timesteps, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch, timesteps, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch, timesteps, self.num_heads, self.head_dim).transpose(1, 2) + + carry_length = key_carry.size(2) + updated_key_carry = torch.cat((key_carry, k), dim=2) + updated_value_carry = torch.cat((value_carry, v), dim=2) + + attn_scores = torch.matmul(q, updated_key_carry.transpose(-2, -1)) / math.sqrt( + float(self.head_dim) + ) + + if timesteps > 1 or carry_length > 0: + total_kv_length = carry_length + timesteps + kv_positions = torch.arange( + total_kv_length, device=hidden.device, dtype=torch.long + ) + query_positions = torch.arange( + timesteps, device=hidden.device, dtype=torch.long + ).unsqueeze(1) + causal_mask = kv_positions.unsqueeze(0) <= (query_positions + carry_length) + attn_scores = attn_scores.masked_fill( + ~causal_mask.unsqueeze(0).unsqueeze(0), float("-inf") + ) + + attn_weights = torch.softmax(attn_scores, dim=-1) + attn_weights = self.attn_dropout(attn_weights) + attn_output = torch.matmul(attn_weights, updated_value_carry) + attn_output = attn_output.transpose(1, 2).reshape( + batch, timesteps, self.embed_dim + ) + attn_output = self.out_proj(attn_output) + + residual = hidden + hidden = residual + self.residual_dropout(attn_output) + + ff_input = self.norm2(hidden) + ff_hidden = self.linear1(ff_input) + ff_hidden = self.ff_dropout(F.gelu(ff_hidden)) + ff_hidden = self.linear2(ff_hidden) + + hidden = hidden + self.residual_dropout(ff_hidden) + return hidden, updated_key_carry, updated_value_carry + + +class TransformerDiscreteSequenceModel(AutoregressiveDiscreteSequenceModel): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + num_heads: int, + ff_hidden_dim: int, + num_layers: int, + max_position_embeddings: int, + dropout: float = 0.0, + positional_embedding: Literal["learned", "sinusoidal"] = "learned", + ) -> None: + super().__init__() + if num_layers <= 0: + raise ValueError("num_layers must be positive.") + if max_position_embeddings <= 0: + raise ValueError("max_position_embeddings must be positive.") + if not 0.0 <= dropout <= 1.0: + raise ValueError("dropout must lie in [0, 1].") + if embedding_dim % num_heads != 0: + raise ValueError("embedding_dim must be divisible by num_heads.") + if positional_embedding not in {"learned", "sinusoidal"}: + raise ValueError("positional_embedding must be 'learned' or 'sinusoidal'.") + + self._vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.ff_hidden_dim = ff_hidden_dim + self.num_layers = num_layers + self.max_position_embeddings = max_position_embeddings + self.head_dim = embedding_dim // num_heads + self._positional_embedding_type = positional_embedding + + self.token_embedding = nn.Embedding( + vocab_size + 1, embedding_dim + ) # +1 for BOS token + if self._positional_embedding_type == "learned": + self.position_embedding = nn.Embedding( + max_position_embeddings, embedding_dim + ) + else: + self.position_embedding = SinusoidalPositionalEmbedding( + embedding_dim=embedding_dim, + max_length=max_position_embeddings, + ) + self.embedding_dropout = nn.Dropout(dropout) + + blocks: list[_AutoregressiveTransformerBlock] = [] + for _ in range(num_layers): + blocks.append( + _AutoregressiveTransformerBlock( + embed_dim=embedding_dim, + num_heads=num_heads, + ff_hidden_dim=ff_hidden_dim, + dropout=dropout, + ) + ) + + self.layers = nn.ModuleList(blocks) + self.final_norm = nn.LayerNorm(embedding_dim) + self.output_projection = nn.Linear(embedding_dim, vocab_size) + self.key_names = [f"key_{idx}" for idx in range(num_layers)] + self.value_names = [f"value_{idx}" for idx in range(num_layers)] + + def init_carry( + self, + batch_size: int, + device: torch.device, + ) -> dict[str, torch.Tensor]: + weight = self.token_embedding.weight + carry: dict[str, torch.Tensor] = { + "position": torch.zeros(batch_size, dtype=torch.long, device=device), + } + empty_key = weight.new_empty(batch_size, self.num_heads, 0, self.head_dim).to( + device + ) + empty_value = weight.new_empty(batch_size, self.num_heads, 0, self.head_dim).to( + device + ) + for key_name, value_name in zip(self.key_names, self.value_names): + carry[key_name] = empty_key.clone() + carry[value_name] = empty_value.clone() + + return carry + + def forward( + self, + x: torch.Tensor, + carry: dict[str, torch.Tensor], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + if x.dim() != 2: + raise ValueError("Expected input tensor with shape (batch, timesteps).") + + batch, timesteps = x.size() + device = x.device + if "position" not in carry: + raise KeyError("Carry must include a 'position' tensor.") + + positions = carry["position"] + if positions.size(0) != batch: + raise ValueError( + "Position carry batch dimension does not match the provided tokens." + ) + if positions.device != device: + raise ValueError( + "Position tensor must live on the same device as input tokens." + ) + if torch.any(positions >= self.max_position_embeddings): + raise ValueError( + "Position index exceeds configured positional embedding range." + ) + + position_offsets = torch.arange(timesteps, device=device, dtype=positions.dtype) + position_indices = positions.unsqueeze(1) + position_offsets + if torch.any(position_indices >= self.max_position_embeddings): + raise ValueError( + "Position index exceeds configured positional embedding range." + ) + + hidden = self.token_embedding(x) + self.position_embedding(position_indices) + hidden = self.embedding_dropout(hidden) + + updated_carry: dict[str, torch.Tensor] = {} + + for idx, layer in enumerate(self.layers): + key_name = self.key_names[idx] + value_name = self.value_names[idx] + if key_name not in carry or value_name not in carry: + raise KeyError( + "Transformer carry is missing key/value tensors for layer" f" {idx}." + ) + key_carry = carry[key_name] + value_carry = carry[value_name] + if key_carry.size(0) != batch or key_carry.size(1) != self.num_heads: + raise ValueError( + "Key carry shape is incompatible with the provided tokens." + ) + if value_carry.size(0) != batch or value_carry.size(1) != self.num_heads: + raise ValueError( + "Value carry shape is incompatible with the provided tokens." + ) + if ( + key_carry.size(-1) != self.head_dim + or value_carry.size(-1) != self.head_dim + ): + raise ValueError("Key/value carry head dimension mismatch detected.") + if key_carry.device != device or value_carry.device != device: + raise ValueError("Key/value carry tensors must share the input device.") + hidden, updated_key_carry, updated_value_carry = layer( + hidden, key_carry, value_carry + ) + updated_carry[key_name] = updated_key_carry + updated_carry[value_name] = updated_value_carry + + hidden = self.final_norm(hidden) + logits = self.output_projection(hidden) + + updated_carry["position"] = positions + timesteps + return logits, updated_carry + + @property + def vocab_size(self) -> int: + return self._vocab_size + + +def sinusoidal_position_encoding( + length: int, + embedding_dim: int, + base: float = 10000.0, +) -> Tensor: + """Create 1D sinusoidal positional embeddings. + + Args: + length: Number of positions to encode. Must be non-negative. + embedding_dim: Dimensionality of each embedding. Must be positive. + base: Exponential base used to compute the angular frequencies. + + Returns: + A ``(length, embedding_dim)`` tensor of sinusoidal encodings. + + Raises: + ValueError: If ``length`` is negative, ``embedding_dim`` is not positive, + or ``base`` is not positive. + """ + + assert length >= 0, "length must be non-negative." + assert embedding_dim > 0, "embedding_dim must be positive." + assert base > 0, "base must be positive." + + if length == 0: + return torch.empty(0, embedding_dim) + + positions = torch.arange(length).unsqueeze(1) + div_input = torch.arange(0, embedding_dim, 2) + div_term = torch.exp(div_input * (-math.log(base) / embedding_dim)) + embeddings = torch.zeros(length, embedding_dim) + angles = positions * div_term + embeddings[:, 0::2] = torch.sin(angles) + + if embedding_dim % 2 == 0: + embeddings[:, 1::2] = torch.cos(angles) + else: + embeddings[:, 1::2] = torch.cos(angles)[:, : embedding_dim // 2] + + return embeddings + + +class SinusoidalPositionalEmbedding(nn.Module): + """Sinusoidal positional embeddings for transformer-style models. + + The module caches a precomputed table of embeddings and extends it on demand. + Forward accepts either a sequence length or explicit position indices. + """ + + def __init__( + self, + embedding_dim: int, + max_length: int = 2048, + base: float = 10000.0, + ) -> None: + super().__init__() + assert max_length >= 0, "max_length must be non-negative." + assert embedding_dim > 0, "embedding_dim must be positive." + assert base > 0, "base must be positive." + + self.embedding_dim = int(embedding_dim) + self.base = float(base) + + pe = sinusoidal_position_encoding(max_length, self.embedding_dim, base=self.base) + self._pe: Tensor + self.register_buffer("_pe", pe) + + @property + def pe(self) -> Tensor: + """Return the cached positional embedding table.""" + return self._pe + + def forward( + self, + positions: Optional[Tensor] = None, + seq_len: Optional[int] = None, + ) -> Tensor: + """Look up positional embeddings. + + Args: + positions: Optional tensor of position indices. Can have any shape, + and the returned embeddings will append ``embedding_dim`` to that + shape. Defaults to ``None``. + seq_len: Optional sequence length. When provided, returns the first + ``seq_len`` embeddings from the table. + + Returns: + Tensor of positional embeddings on the same device/dtype as the + cached table. + + Raises: + ValueError: If both or neither of ``positions`` and ``seq_len`` are + provided, or if indices exceed the cached range. + """ + + if (positions is None) == (seq_len is None): + raise ValueError("Provide exactly one of positions or seq_len.") + + if positions is not None: + flat_positions = positions.reshape(-1) + gathered = self._pe.index_select(0, flat_positions) + return gathered.view( + positions.shape[0], positions.shape[1], self.embedding_dim + ) + else: + return self._pe[:seq_len] diff --git a/src/gfn/utils/prob_calculations.py b/src/gfn/utils/prob_calculations.py index 7f0f547f..41b24c3f 100644 --- a/src/gfn/utils/prob_calculations.py +++ b/src/gfn/utils/prob_calculations.py @@ -1,43 +1,9 @@ -from typing import Tuple +from typing import Any, Tuple, cast import torch from gfn.containers import Trajectories, Transitions -from gfn.estimators import Estimator -from gfn.states import States -from gfn.utils.handlers import ( - has_conditioning_exception_handler, - no_conditioning_exception_handler, -) - - -def check_cond_forward( - module: Estimator, - module_name: str, - states: States, - condition: torch.Tensor | None = None, -) -> torch.Tensor: - """Checks if conditioning is passed and calls the module's forward method accordingly. - - Args: - module: The GFN module to call. - module_name: The name of the module (for error messages). - states: The states to pass to the module. - condition: Optional conditioning tensor to pass to the module. - - Returns: - The output of the module's forward method. - - Raises: - TypeError: If conditioning is passed but the module does not accept it, or vice-versa. - """ - if condition is not None: - with has_conditioning_exception_handler(module_name, module): - return module(states, condition) - else: - with no_conditioning_exception_handler(module_name, module): - return module(states) - +from gfn.estimators import Estimator, PolicyEstimatorProtocol, RecurrentPolicyMixin # ------------ # Trajectories @@ -50,33 +16,41 @@ def get_trajectory_pfs_and_pbs( trajectories: Trajectories, fill_value: float = 0.0, recalculate_all_logprobs: bool = True, + **policy_kwargs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculates the log probabilities of forward and backward trajectories. + """Calculate PF and PB log‑probabilities for trajectories. + + Delegates to ``get_trajectory_pfs`` and ``get_trajectory_pbs`` while + forwarding policy kwargs. Args: - pf: The forward policy estimator. - pb: The backward policy estimator, or None if the gflownet DAG is a tree, and - pb is therefore always 1. - trajectories: The trajectories to calculate probabilities for. - fill_value: The value to fill for invalid states (e.g., sink states). - recalculate_all_logprobs: Whether to recalculate log probabilities even if they - already exist in the trajectories object. + pf: Forward policy estimator. + pb: Backward policy estimator, or ``None`` for trees (PB=1). + trajectories: Trajectories to evaluate. + fill_value: Value used to pad invalid positions. + recalculate_all_logprobs: If True, recompute PF even if cached. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - A tuple containing two tensors: log_pf_trajectories and log_pb_trajectories. + ``(log_pf[T,N], log_pb[T,N])`` """ + # TODO: Remove this assertion and move to a test. # fill value is the value used for invalid states (sink state usually) - - # uncomment next line for debugging - # assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy) + assert trajectories.states.is_sink_state[:-1].equal(trajectories.actions.is_dummy) log_pf_trajectories = get_trajectory_pfs( pf, trajectories, fill_value=fill_value, recalculate_all_logprobs=recalculate_all_logprobs, + **policy_kwargs, + ) + log_pb_trajectories = get_trajectory_pbs( + pb, + trajectories, + fill_value=fill_value, + **policy_kwargs, ) - log_pb_trajectories = get_trajectory_pbs(pb, trajectories, fill_value=fill_value) return log_pf_trajectories, log_pb_trajectories @@ -86,22 +60,32 @@ def get_trajectory_pfs( trajectories: Trajectories, fill_value: float = 0.0, recalculate_all_logprobs: bool = True, + **policy_kwargs: Any, ) -> torch.Tensor: - """Calculates the log probabilities of forward trajectories. + """Calculate PF log‑probabilities for trajectories. + + Non‑vectorized (per‑step) evaluation with masks + ``~is_sink_state[t] & ~is_dummy[t]`` & no action‑id indexing is supported when + specifically needed (estimator.is_vectorized=False). Args: - pf: The forward policy estimator. - trajectories: The trajectories to calculate probabilities for. - fill_value: The value to fill for invalid states (e.g., sink states). - recalculate_all_logprobs: Whether to recalculate log probabilities even if they - already exist in the trajectories object. + pf: Forward policy estimator. + trajectories: Trajectories to evaluate. + fill_value: Value used to pad invalid positions. + recalculate_all_logprobs: If True, recompute PF even if cached. Useful for + off-policy training. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - A tensor containing the log probabilities of the forward trajectories. + ``log_pf`` of shape ``(T, N)``. Raises: ValueError: If backward trajectories are provided. """ + # TODO: Ensure that the estimator is policy-capable here. + if not hasattr(pf, "init_context"): + raise TypeError("Estimator is not policy-capable (missing PolicyMixin)") + if trajectories.is_backward: raise ValueError("Backward trajectories are not supported") @@ -118,38 +102,136 @@ def get_trajectory_pfs( log_pf_trajectories = trajectories.log_probs assert log_pf_trajectories is not None else: - log_pf_trajectories = torch.full_like( - trajectories.actions.tensor[..., 0], - fill_value=fill_value, - dtype=torch.get_default_dtype(), # Floating point dtype. - ) - if len(valid_states) == 0: - return log_pf_trajectories + # Decide vectorized vs non-vectorized based on estimator capability + # Tell the type-checker we expect the Policy mixin surface here. + policy_pf = cast(PolicyEstimatorProtocol, pf) + # Runtime guard: ensure the estimator actually implements the required protocol + # method and raises an error when a non‑policy estimator is supplied. + for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_pf, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + is_vectorized = bool(getattr(policy_pf, "is_vectorized", True)) + + if not is_vectorized: + # Per-step path. + N = trajectories.n_trajectories + device = trajectories.states.device + cond = trajectories.conditioning + + # TODO: Why do we need this? + if cond is not None and len(cond.shape) >= 2: + cond = cond[0] + + ctx = policy_pf.init_context(int(N), device, cond) # type: ignore[arg-type] + + T = trajectories.max_length + log_pf_trajectories = torch.full( + (T, N), + fill_value=fill_value, + dtype=torch.get_default_dtype(), + device=device, + ) + + for t in range(T): + state_ok = ~trajectories.states.is_sink_state[t] + action_ok = ~trajectories.actions.is_dummy[t] + step_mask = state_ok & action_ok + + if not torch.any(step_mask): + continue + + step_states = trajectories.states[t][step_mask] + step_actions = trajectories.actions.tensor[t][step_mask] + + # Optimization: forward cached estimator outputs when available + if ( + trajectories.estimator_outputs is not None + and not recalculate_all_logprobs + ): + ctx.current_estimator_output = trajectories.estimator_outputs[t][ + step_mask + ] + else: + # Ensure we do not accidentally reuse estimator outputs from a + # previous time step. Precomputed outputs must be provided + # explicitly for the current step. + ctx.current_estimator_output = None + + # Build distribution for active rows and compute step log-probs + dist, ctx = policy_pf.compute_dist( + step_states, ctx, step_mask, **policy_kwargs + ) + step_log_probs, ctx = policy_pf.log_probs( + step_actions, dist, ctx, step_mask, vectorized=False + ) + + # Pad back to full batch size. + if fill_value != 0.0: + padded = torch.full( + (N,), fill_value, device=device, dtype=step_log_probs.dtype + ) + padded[step_mask] = step_log_probs[step_mask] + step_log_probs = padded + + # Store in trajectory-level tensor. + log_pf_trajectories[t] = step_log_probs - if trajectories.estimator_outputs is not None and not recalculate_all_logprobs: - estimator_outputs = trajectories.estimator_outputs[action_mask] else: + # Vectorized path. + log_pf_trajectories = torch.full_like( + trajectories.actions.tensor[..., 0], + fill_value=fill_value, + dtype=torch.get_default_dtype(), + ) + + if len(valid_states) == 0: + return log_pf_trajectories + + # Build conditioning per-step shape to align with valid_states masked_cond = None - if trajectories.conditioning is not None: - cond_dim = (-1,) * len(trajectories.conditioning.shape) - traj_len = trajectories.states.tensor.shape[0] - masked_cond = trajectories.conditioning.unsqueeze(0).expand( - (traj_len,) + cond_dim - )[state_mask] - - estimator_outputs = check_cond_forward(pf, "pf", valid_states, masked_cond) - - # Calculates the log PF of the actions sampled off policy. - valid_log_pf_actions = pf.to_probability_distribution( - valid_states, estimator_outputs - ).log_prob( - valid_actions.tensor - ) # Using the actions sampled off-policy. - - log_pf_trajectories[action_mask] = valid_log_pf_actions.to( - log_pf_trajectories.dtype, copy=False - ) + cond = trajectories.conditioning + + if cond is not None: + T = trajectories.states.tensor.shape[0] + # If conditioning already has time dim (T, N, ...), index directly. + if cond.shape[0] == T: + masked_cond = cond[state_mask] + else: + # Broadcast (N, ...) to (T, N, ...), then index. + masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask] + + ctx_v = policy_pf.init_context( + int(len(valid_states)), + trajectories.states.device, + conditioning=masked_cond, + ) + + # Optional estimator output cache reuse. + if ( + trajectories.estimator_outputs is not None + and not recalculate_all_logprobs + ): + estimator_outputs = trajectories.estimator_outputs[action_mask] + ctx_v.current_estimator_output = estimator_outputs + + # Build distribution and compute vectorized log-probs + dist, ctx_v = policy_pf.compute_dist( + valid_states, + ctx_v, + step_mask=None, + **policy_kwargs, + ) + valid_log_pf_actions, _ = policy_pf.log_probs( + valid_actions.tensor, dist, ctx_v, step_mask=None, vectorized=True + ) + + # Pad back to full batch size. + log_pf_trajectories[action_mask] = valid_log_pf_actions.to( + log_pf_trajectories.dtype, copy=False + ) assert log_pf_trajectories.shape == ( trajectories.max_length, @@ -163,17 +245,24 @@ def get_trajectory_pbs( pb: Estimator | None, trajectories: Trajectories, fill_value: float = 0.0, + **policy_kwargs: Any, ) -> torch.Tensor: - """Calculates the log probabilities of backward trajectories. + """Calculate PB log‑probabilities for trajectories. + + Non‑vectorized (per‑step) evaluation with with alignment + (action at ``t`` with state at ``t+1``) and mask + ``~is_sink_state[t+1] & ~is_initial_state[t+1] & ~is_dummy[t] & ~is_exit[t]``; + skip ``t==0``. is supported when specifically needed + (estimator.is_vectorized=False). Args: - pb: The backward policy estimator. - trajectories: The trajectories to calculate probabilities for. - fill_value: The value to fill for invalid states (e.g., sink states). - dtype: The dtype of the log probabilities. + pb: Backward policy estimator, or ``None`` for trees (PB=1). + trajectories: Trajectories to evaluate. + fill_value: Value used to pad invalid positions. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - A tensor containing the log probabilities of the backward trajectories. + ``log_pb`` of shape ``(T, N)``. Raises: ValueError: If backward trajectories are provided. @@ -207,26 +296,104 @@ def get_trajectory_pbs( # Using all non-initial states, calculate the backward policy, and the logprobs # of those actions. masked_cond = None - if trajectories.conditioning is not None: - # We need to index the conditioning vector to broadcast over the states. - # The conditioning tensor has shape (max_length, n_trajectories, 1) - # We need to index it with the state_mask to get the valid states - masked_cond = trajectories.conditioning[state_mask] - - if pb is not None: - estimator_outputs = check_cond_forward(pb, "pb", valid_states, masked_cond) - valid_log_pb_actions = pb.to_probability_distribution( - valid_states, estimator_outputs - ).log_prob(valid_actions.tensor) + cond = trajectories.conditioning + if cond is not None: + T = trajectories.states.tensor.shape[0] + if cond.shape[0] == T: + masked_cond = cond[state_mask] + else: + masked_cond = cond.unsqueeze(0).expand((T,) + cond.shape)[state_mask] - else: + # There is no backward policy in this case. + if pb is None: # If pb is None, we assume that the gflownet DAG is a tree, and therefore # the backward policy probability is always 1 (log probs are 0). valid_log_pb_actions = torch.zeros_like(valid_actions.tensor) + valid_log_pb_actions = valid_log_pb_actions.squeeze(-1) # no padding. + log_pb_trajectories[action_mask] = valid_log_pb_actions.to( + log_pb_trajectories.dtype, copy=False + ) - log_pb_trajectories[action_mask] = valid_log_pb_actions.to( - log_pb_trajectories.dtype, copy=False - ) + assert log_pb_trajectories.shape == ( + trajectories.max_length, + trajectories.n_trajectories, + ) + + return log_pb_trajectories + + # There is a backward policy. + policy_pb = cast(PolicyEstimatorProtocol, pb) + # Runtime guard: ensure the estimator actually implements the required protocol + # method and raises an error when a non‑policy estimator is supplied. + for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_pb, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + is_vectorized = bool(getattr(policy_pb, "is_vectorized", True)) + + if not is_vectorized: + # Per-step pb evaluation (state at t+1, action at t) + N = trajectories.n_trajectories + device = trajectories.states.device + cond = trajectories.conditioning + if cond is not None and len(cond.shape) >= 2: + cond_step0 = cond[0] # TODO: Why do we need this? + ctx = policy_pb.init_context(int(N), device, cond_step0) # type: ignore[arg-type] + + # Iterate per-step with masking (state at t+1, action at t) + for t in range(trajectories.max_length): + # TODO: these checks are curious - I think one of them is never needed + # because for now we do not support reversed trajectories. + next_state_isnt_sink = ~trajectories.states.is_sink_state[t + 1] + next_state_isnt_initial = ~trajectories.states.is_initial_state[t + 1] + state_ok = next_state_isnt_sink & next_state_isnt_initial + if t == 0: + # log PB is always zero for the transition s1 -> s0. + state_ok = torch.zeros_like(state_ok, dtype=torch.bool) + + action_ok = (~trajectories.actions.is_dummy[t]) & ( + ~trajectories.actions.is_exit[t] + ) + step_mask = state_ok & action_ok + + if not torch.any(step_mask): + continue + + step_states = trajectories.states[t + 1][step_mask] + step_actions = trajectories.actions.tensor[t][step_mask] + + # Prevent reusing last step's estimator output (batch size may differ, + # and estimator output caching isn't needed for PB). + ctx.current_estimator_output = None + dist, ctx = policy_pb.compute_dist( + step_states, ctx, step_mask, **policy_kwargs + ) + step_lp, ctx = policy_pb.log_probs( + step_actions, dist, ctx, step_mask, vectorized=False + ) + + padded = torch.full((N,), fill_value, device=device, dtype=step_lp.dtype) + padded[step_mask] = step_lp[step_mask] + log_pb_trajectories[t] = padded + + # The backward policy supports vectorized evaluation. + else: + ctx_v = policy_pb.init_context( + int(len(valid_states)), trajectories.states.device, conditioning=masked_cond # type: ignore[arg-type] + ) + dist, ctx_v = policy_pb.compute_dist( + valid_states, + ctx_v, + step_mask=None, + **policy_kwargs, + ) + valid_log_pb_actions, _ = policy_pb.log_probs( + valid_actions.tensor, dist, ctx_v, step_mask=None, vectorized=True + ) + log_pb_trajectories[action_mask] = valid_log_pb_actions.to( + log_pb_trajectories.dtype, copy=False + ) assert log_pb_trajectories.shape == ( trajectories.max_length, @@ -246,19 +413,20 @@ def get_transition_pfs_and_pbs( pb: Estimator | None, transitions: Transitions, recalculate_all_logprobs: bool = True, + **policy_kwargs: Any, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Calculates the log probabilities of forward and backward transitions. + """Calculate PF and PB log‑probabilities for transitions. Args: - pf: The forward policy estimator. - pb: The backward policy estimator, or None if the gflownet DAG is a tree, and - pb is therefore always 1. - transitions: The transitions to calculate probabilities for. - recalculate_all_logprobs: Whether to recalculate log probabilities even if they - already exist in the transitions object. + pf: Forward policy estimator. + pb: Backward policy estimator, or ``None`` for trees (PB=1). + transitions: Transitions to evaluate. + recalculate_all_logprobs: If True, recompute PF even if cached. Useful for + off-policy training. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - A tuple containing two tensors: log_pf_transitions and log_pb_transitions. + ``(log_pf[M], log_pb[M])``. Raises: ValueError: If backward transitions are provided. @@ -266,8 +434,10 @@ def get_transition_pfs_and_pbs( if transitions.is_backward: raise ValueError("Backward transitions are not supported") - log_pf_transitions = get_transition_pfs(pf, transitions, recalculate_all_logprobs) - log_pb_transitions = get_transition_pbs(pb, transitions) + log_pf_transitions = get_transition_pfs( + pf, transitions, recalculate_all_logprobs, **policy_kwargs + ) + log_pb_transitions = get_transition_pbs(pb, transitions, **policy_kwargs) assert log_pf_transitions.shape == (transitions.n_transitions,) assert log_pb_transitions.shape == (transitions.n_transitions,) @@ -276,18 +446,22 @@ def get_transition_pfs_and_pbs( def get_transition_pfs( - pf: Estimator, transitions: Transitions, recalculate_all_logprobs: bool = True + pf: Estimator, + transitions: Transitions, + recalculate_all_logprobs: bool = True, + **policy_kwargs: Any, ) -> torch.Tensor: - """Calculates the log probabilities of forward transitions. + """Calculate PF log‑probabilities for transitions. Args: - pf: The forward policy estimator. - transitions: The transitions to calculate probabilities for. - recalculate_all_logprobs: Whether to recalculate log probabilities even if they - already exist in the transitions object. + pf: Forward policy estimator. + transitions: Transitions to evaluate. + recalculate_all_logprobs: If True, recompute PF even if cached. Useful for + off-policy training. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. Returns: - A tensor containing the log probabilities of the forward transitions. + ``log_pf`` of shape ``(M,)``. """ states = transitions.states actions = transitions.actions @@ -296,40 +470,54 @@ def get_transition_pfs( log_pf_actions = transitions.log_probs assert log_pf_actions is not None else: - # Evaluate the log PF of the actions, with optional conditioning. + + if isinstance(pf, RecurrentPolicyMixin): + raise TypeError("RecurrentPolicyMixin is only supported for Trajectories") + + N = transitions.n_transitions + device = transitions.states.device + cond = transitions.conditioning + + # For static typing, cast to the policy protocol before calling mixin methods. + policy_pf = cast(PolicyEstimatorProtocol, pf) + # Runtime guard: ensure the estimator actually implements the required protocol + # method and raises an error when a non‑policy estimator is supplied. + for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_pf, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + ctx = policy_pf.init_context(int(N), device, cond) + mask = torch.ones(N, dtype=torch.bool, device=device) + + # Evaluate the log PF of the actions # TODO: Inefficient duplication in case of tempered policy # The Transitions container should then have some # estimator_outputs attribute as well, to avoid duplication here ? # See (#156). - estimator_outputs = check_cond_forward( - pf, "pf", states, transitions.conditioning + dist, ctx = policy_pf.compute_dist(states[mask], ctx, mask, **policy_kwargs) + log_pf_actions, _ = policy_pf.log_probs( + actions.tensor[mask], dist, ctx, mask, vectorized=False ) - log_pf_actions = pf.to_probability_distribution( - states, estimator_outputs - ).log_prob(actions.tensor) - return log_pf_actions -def get_transition_pbs(pb: Estimator | None, transitions: Transitions) -> torch.Tensor: - """Calculates the log probabilities of backward transitions. +def get_transition_pbs( + pb: Estimator | None, + transitions: Transitions, + **policy_kwargs: Any, +) -> torch.Tensor: + """Calculate PB log‑probabilities for transitions. Args: - pb: The backward policy Estimator, or None if the gflownet DAG is a tree, and - pb is therefore always 1. - transitions: The transitions to calculate probabilities for. + pb: Backward policy estimator, or ``None`` for trees (PB=1). + transitions: Transitions to evaluate. + **policy_kwargs: Extra kwargs for ``to_probability_distribution``. + + Returns: + ``log_pb`` of shape ``(M,)``. """ - # automatically removes invalid transitions (i.e. s_f -> s_f) - valid_next_states = transitions.next_states[~transitions.is_terminating] - non_exit_actions = transitions.actions[~transitions.actions.is_exit] - - # Evaluate the log PB of the actions, with optional conditioning. - masked_cond = ( - transitions.conditioning[~transitions.is_terminating] - if transitions.conditioning is not None - else None - ) # TODO: We support a fill_value for trajectories, but not for transitions. # Should we add it here, or remove it for trajectories? @@ -339,15 +527,44 @@ def get_transition_pbs(pb: Estimator | None, transitions: Transitions) -> torch. # If pb is None, we assume that the gflownet DAG is a tree, and therefore # the backward policy probability is always 1 (log probs are 0). - if pb is not None: - estimator_outputs = check_cond_forward(pb, "pb", valid_next_states, masked_cond) + if pb is None: + return log_pb_actions + + if not hasattr(pb, "init_context"): + raise TypeError("Estimator is not policy-capable (missing PolicyMixin)") + + if isinstance(pb, RecurrentPolicyMixin): + raise TypeError("RecurrentPolicyMixin is only supported for Trajectories") + + # For static typing, cast to the policy protocol before calling mixin methods. + policy_pb = cast(PolicyEstimatorProtocol, pb) + # Runtime guard: ensure the estimator actually implements the required protocol + # method and raises an error when a non‑policy estimator is supplied. + for required in ("init_context", "compute_dist", "log_probs"): + if not hasattr(policy_pb, required): + raise TypeError( + f"Estimator is not policy-capable (missing PolicyMixin method: {required})" + ) + ctx = policy_pb.init_context( + int(transitions.n_transitions), + transitions.states.device, + transitions.conditioning, + ) + + # Legacy-complete masking for PB on transitions: + # require non-terminating next_states and non-exit actions simultaneously + # automatically removes invalid transitions (i.e. s_f -> s_f) + mask = ~transitions.is_terminating & ~transitions.actions.is_exit - # Evaluate the log PB of the actions. - valid_log_pb_actions = pb.to_probability_distribution( - valid_next_states, estimator_outputs - ).log_prob(non_exit_actions.tensor) + if not torch.any(mask): + return log_pb_actions - if len(valid_next_states) != 0: - log_pb_actions[~transitions.is_terminating] = valid_log_pb_actions + dist, ctx = policy_pb.compute_dist( + transitions.next_states[mask], ctx, mask, **policy_kwargs + ) + step_lp, _ = policy_pb.log_probs( + transitions.actions.tensor[mask], dist, ctx, mask, vectorized=False + ) + log_pb_actions[mask] = step_lp[mask] return log_pb_actions diff --git a/testing/test_adaptor_estimator_gflownet_integration.py b/testing/test_adaptor_estimator_gflownet_integration.py new file mode 100644 index 00000000..eb943228 --- /dev/null +++ b/testing/test_adaptor_estimator_gflownet_integration.py @@ -0,0 +1,219 @@ +import warnings + +import pytest +import torch + +from gfn.estimators import ( + DiscretePolicyEstimator, + RecurrentDiscretePolicyEstimator, + ScalarEstimator, +) +from gfn.gflownet import DBGFlowNet, TBGFlowNet +from gfn.gym.bitSequence import BitSequence +from gfn.preprocessors import IdentityPreprocessor +from gfn.utils.modules import MLP, RecurrentDiscreteSequenceModel + + +def _make_bitsequence_env( + *, device: torch.device, word_size: int = 3, seq_size: int = 9, n_modes: int = 5 +) -> BitSequence: + H = torch.randint(0, 2, (n_modes, seq_size), dtype=torch.long, device=device) + env = BitSequence( + word_size=word_size, + seq_size=seq_size, + n_modes=n_modes, + temperature=1.0, + H=H, + device_str=str(device), + seed=0, + check_action_validity=True, + ) + return env + + +def _make_recurrent_pf( + env: BitSequence, device: torch.device +) -> RecurrentDiscretePolicyEstimator: + model = RecurrentDiscreteSequenceModel( + vocab_size=env.n_actions, + embedding_dim=16, + hidden_size=32, + num_layers=1, + rnn_type="lstm", + dropout=0.0, + ).to(device) + pf = RecurrentDiscretePolicyEstimator( + module=model, n_actions=env.n_actions, is_backward=False + ).to(device) + return pf + + +def _make_nonrecurrent_pf_pb(env: BitSequence, device: torch.device): + # BitSequence states are integer words of length words_per_seq + input_dim = env.words_per_seq + preprocessor = IdentityPreprocessor(output_dim=input_dim) + + pf_module = MLP( + input_dim=input_dim, output_dim=env.n_actions, hidden_dim=32, n_hidden_layers=1 + ).to(device) + pb_module = MLP( + input_dim=input_dim, + output_dim=env.n_actions - 1, + hidden_dim=32, + n_hidden_layers=1, + ).to(device) + pf = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=preprocessor, + ).to(device) + pb = DiscretePolicyEstimator( + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=preprocessor, + ).to(device) + return pf, pb + + +def test_recurrent_tb_passes_with_pb_none(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf = _make_recurrent_pf(env, device) + gfn = TBGFlowNet(pf=pf, pb=None, init_logZ=0.0, constant_pb=True) + + # sample and compute a loss to ensure end-to-end path works + trajectories = gfn.sample_trajectories( + env, n=4, save_logprobs=True, save_estimator_outputs=False + ) + loss = gfn.loss(env, trajectories, recalculate_all_logprobs=False) + assert torch.isfinite(loss) + + +def test_warn_on_recurrent_pf_with_nonrecurrent_pb(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf = _make_recurrent_pf(env, device) + pb_pf, pb = _make_nonrecurrent_pf_pb(env, device) + del pb_pf # unused + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + _ = TBGFlowNet(pf=pf, pb=pb, init_logZ=0.0, constant_pb=False) + assert any("unusual" in str(x.message).lower() for x in w) + + +def test_error_on_recurrent_pb(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf_nonrec, _ = _make_nonrecurrent_pf_pb(env, device) + + # Build a recurrent PB + model = RecurrentDiscreteSequenceModel( + vocab_size=env.n_actions - 1, + embedding_dim=16, + hidden_size=32, + num_layers=1, + rnn_type="lstm", + dropout=0.0, + ).to(device) + pb_recurrent = RecurrentDiscretePolicyEstimator( + module=model, n_actions=env.n_actions, is_backward=True + ).to(device) + + with pytest.raises(TypeError, match="Recurrent PB estimators are not supported"): + _ = TBGFlowNet(pf=pf_nonrec, pb=pb_recurrent, init_logZ=0.0, constant_pb=False) + + +def test_db_gflownet_rejects_recurrent_pf_and_adapter(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf = _make_recurrent_pf(env, device) + + # recurrent PF should be rejected + logF_est = ScalarEstimator( + module=MLP( + input_dim=env.words_per_seq, + output_dim=1, + hidden_dim=16, + n_hidden_layers=1, + ).to(device) + ) + with pytest.raises(TypeError, match="does not support recurrent PF"): + _ = DBGFlowNet( + pf=pf, + pb=None, + logF=logF_est, + constant_pb=True, + ) # type: ignore[arg-type] + + # Non-recurrent PF should be accepted (adapters are now part of estimators) + pf_nonrec, _ = _make_nonrecurrent_pf_pb(env, device) + _ = DBGFlowNet( + pf=pf_nonrec, + pb=None, + logF=logF_est, + constant_pb=True, + ) # type: ignore[arg-type] + + +def test_nonrecurrent_tb_passes_with_pb_defined(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + pf, pb = _make_nonrecurrent_pf_pb(env, device) + gfn = TBGFlowNet(pf=pf, pb=pb, init_logZ=0.0, constant_pb=False) + + trajectories = gfn.sample_trajectories( + env, n=3, save_logprobs=True, save_estimator_outputs=False + ) + loss = gfn.loss(env, trajectories, recalculate_all_logprobs=False) + assert torch.isfinite(loss) + + +def test_pb_mlp_trunk_sharing_parity_on_transitions(): + device = torch.device("cpu") + env = _make_bitsequence_env(device=device) + + # Build non-recurrent PF for sampling + pf, _ = _make_nonrecurrent_pf_pb(env, device) + + # PB with trunk sharing from PF + pb_shared_module = MLP( + input_dim=env.words_per_seq, + output_dim=env.n_actions - 1, + hidden_dim=32, + n_hidden_layers=1, + trunk=pf.module.trunk, # type: ignore[attr-defined] + ).to(device) + pb_shared = DiscretePolicyEstimator( + module=pb_shared_module, n_actions=env.n_actions, is_backward=True + ).to(device) + + # PB independent with identical weights copied from shared version + pb_indep_module = MLP( + input_dim=env.words_per_seq, + output_dim=env.n_actions - 1, + hidden_dim=32, + n_hidden_layers=1, + ).to(device) + pb_indep_module.load_state_dict(pb_shared_module.state_dict()) + pb_indep = DiscretePolicyEstimator( + module=pb_indep_module, n_actions=env.n_actions, is_backward=True + ).to(device) + + # Sample trajectories and convert to transitions + from gfn.samplers import Sampler + + sampler = Sampler(estimator=pf) + trajectories = sampler.sample_trajectories( + env, n=5, save_logprobs=False, save_estimator_outputs=False + ) + transitions = trajectories.to_transitions() + + # Compute PB log-probs using vectorized default adapters for each PB + from gfn.utils.prob_calculations import get_transition_pbs + + lp_shared = get_transition_pbs(pb_shared, transitions) + lp_indep = get_transition_pbs(pb_indep, transitions) + + torch.testing.assert_close(lp_shared, lp_indep) diff --git a/testing/test_chunking.py b/testing/test_chunking.py new file mode 100644 index 00000000..baedf74b --- /dev/null +++ b/testing/test_chunking.py @@ -0,0 +1,375 @@ +from __future__ import annotations + +import random +from typing import List, Sequence, cast + +import torch +import torch.nn as nn + +from gfn.chunking.policy import ActionEncoder, ChunkedPolicy +from gfn.containers import Trajectories +from gfn.env import ChunkedDiscreteEnvironment +from gfn.states import ChunkedStates, DiscreteStates + +# from gfn.chunking.adapters import ChunkedAdapter + + +class SyntheticTokenEnv(ChunkedDiscreteEnvironment): + def __init__(self, device: torch.device = torch.device("cpu")): + # n_actions = 27: A=1,B=2,C=3,D=4,...,Z=26,EXIT=27 + n_actions = 27 + s0 = torch.tensor([0], device=device) + # Tokenizer maps primitive ints to letters for string-based chunkers. + + def _letters_tokenizer(seq: Sequence[int]) -> str: + alpha = {i: chr(ord("A") + i - 1) for i in range(1, 27)} # 1->A ... 26->Z + alpha[27] = "EXIT" + + return "".join(alpha.get(i, "") for i in seq) + + super().__init__( + n_actions=n_actions, + s0=s0, + state_shape=(1,), + action_shape=(1,), + check_action_validity=True, + tokenizer=_letters_tokenizer, + exit_action=torch.tensor([27], device=device), + ) + + def update_masks(self, states: ChunkedStates) -> None: + # Forward mask: disallow [B,A] and [C,D] + # Backward mask: inverse constraints: if curr==A, disallow parent B; if curr==D, disallow parent C + device = states.device + batch_shape = states.batch_shape + + # Initialize all true for both masks with correct batch shape + fwd = torch.ones((*batch_shape, self.n_actions), dtype=torch.bool, device=device) + bwd = torch.ones( + (*batch_shape, self.n_actions - 1), dtype=torch.bool, device=device + ) + + # Current token value per batch element (handles (B,1) and (T,B,1)) + last = states.tensor.squeeze(-1) # shape == (*batch_shape,) + + # Forward constraints + disallow_A = last == 2 # if last == B --> disallow A + if disallow_A.any(): + # Mask primitive A (index 0) + fwd[..., 1].masked_fill_(disallow_A, False) + disallow_D = last == 3 # if last == C --> disallow D + if disallow_D.any(): + # Mask primitive D (index 3) + fwd[..., 4].masked_fill_(disallow_D, False) + + # Backward constraints (no EXIT column) + disallow_parent_B = last == 1 # current == A -> disallow parent B + if disallow_parent_B.any(): + bwd[..., 2].masked_fill_(disallow_parent_B, False) + + disallow_parent_C = last == 4 # current == D -> disallow parent C + if disallow_parent_C.any(): + bwd[..., 3].masked_fill_(disallow_parent_C, False) + + states.forward_masks = fwd + states.backward_masks = bwd + + # Overlay global disables and macro feasibility + self.apply_soft_disabled_to_forward_masks(states) + self.apply_macro_forward_mask(states) + + def step(self, states: DiscreteStates, actions) -> DiscreteStates: + # Set state to the action token, unless EXIT; EXIT leads to sink (-1) + device = states.device + a = actions.tensor # preserve shape (matches states.tensor) + new = states.tensor.clone() + # For any EXIT, set sink + is_exit = a == (self.n_actions - 1) + new[is_exit] = torch.tensor([-1], device=device, dtype=new.dtype) + # For non-exit, set to the chosen primitive token id + non_exit = ~is_exit + if a.dtype != new.dtype: + a = a.to(dtype=new.dtype) + new[non_exit] = a[non_exit] + out = self.states_from_tensor(new) + return out + + def backward_step(self, states: DiscreteStates, actions) -> DiscreteStates: + # For synthetic tests, just revert to s0 for simplicity when not exit + device = states.device + a = actions.tensor.view(-1) + new = states.tensor.clone() + # Non-exit moves to a dummy parent; for tests we can use 0 + new[a != (self.n_actions - 1)] = torch.tensor( + [0], device=device, dtype=new.dtype + ) + out = self.states_from_tensor(new) + return out + + +def generate_synthetic_corpus( + n_traj: int = 10000, + length: int = 20, + device: torch.device = torch.device("cpu"), +) -> Trajectories: + # Build N trajectories respecting forward constraints and injecting chunks + # Tokens 1..4 only; no EXIT in the corpus + env = SyntheticTokenEnv(device) + + actions_2d = torch.zeros(length, n_traj, dtype=torch.long, device=device) + term = torch.full((n_traj,), length, dtype=torch.long, device=device) + + subseq = [2, 3, 1, 4, 2] # "BCADB" + for i in range(n_traj): + seq: List[int] = [] + + # Choose an insertion index for BCADB that fits in the sequence + ins_start = random.randint(0, length - len(subseq)) + + t = 0 + while t < length: + + # Add the subsequence here. + if t == ins_start: + seq.extend(subseq) + t += len(subseq) + continue + + last = seq[-1] if seq else 0 + candidates = list(range(1, 27)) + + # Forward constraints from SyntheticTokenEnv.update_masks + if last == 2: # after B, disallow A + candidates.remove(1) + if last == 3: # after C, disallow D + candidates.remove(4) + seq.append(random.choice(candidates)) + t += 1 + + actions_2d[:, i] = torch.tensor(seq[:length], dtype=torch.long, device=device) + + # Wrap into env-specific containers + actions = env.actions_from_tensor(actions_2d.unsqueeze(-1)) + + # Derive states by unrolling tokens as last-observation states + states_tensor = torch.zeros(length + 1, n_traj, 1, dtype=torch.long, device=device) + states_tensor[0, :, 0] = 0 # s0 + states_tensor[1:, :, 0] = actions_2d + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + + return Trajectories( + env=env, + states=states, + actions=actions, + terminating_idx=term, + is_backward=False, + ) + + +class TinyStateModule(nn.Module): + def __init__(self, embed_dim: int = 16): + super().__init__() + self.embed = nn.Embedding( + 6, embed_dim + ) # allow -1,0..4 remapped in preprocessing + self.proj = nn.Linear(embed_dim, 32) + + def forward(self, states: DiscreteStates) -> torch.Tensor: + x = states.tensor.view(-1) + # Remap -1->0, keep 0..4 as is plus 1 offset to avoid negative indices + x = torch.clamp(x + 1, min=0) + e = self.embed(x) + out = self.proj(e) + return out + + +class _ConstState(nn.Module): + """Return a constant state embedding for any input batch. + + This isolates the test to the interaction between `ChunkedPolicy` and + `ActionEncoder`, avoiding any dependency on a learned state network. + """ + + def __init__(self, embed_dim: int) -> None: + super().__init__() + self.embedding = nn.Parameter(torch.zeros(embed_dim), requires_grad=False) + + def forward(self, states: DiscreteStates) -> torch.Tensor: + batch = states.batch_shape[0] + return self.embedding.expand(batch, -1) + + +def test_policy_encoder_growing_action_space_with_synthetic_env(): + # Ensure deterministic encoder behavior + torch.manual_seed(0) + + device = torch.device("cpu") + D = 32 + + # Reuse the synthetic environment with primitives A..Z and EXIT + env = SyntheticTokenEnv(device) + + # Build a real batch of discrete states (values don't matter for constant state net) + states = env.states_from_tensor(torch.tensor([[0], [1], [2]], device=device)) + + # Encoder maps sequences of primitive ids to action embeddings in R^D + encoder = ActionEncoder( + n_primitive_actions=env.n_actions, # primitives + EXIT + action_embedding_dimension=D, + hidden_dim=32, + num_layers=1, + num_head=4, + max_len=8, + dropout=0.0, + ) + + # Policy produces logits via scaled dot product between state and action embeddings + policy = ChunkedPolicy(_ConstState(D), encoder, env, action_embedding_dim=D) + + # First pass: primitives only + logits1 = policy.forward_logits(states) + assert logits1.shape == (states.batch_shape[0], env.n_actions) + emb1 = policy._library_embeddings.detach().clone() + assert torch.isfinite(emb1).all() + + # Grow action space by adding a length-5 macro (BCADB) + env.add_tokens([(2, 3, 1, 4, 2)]) + logits2 = policy.forward_logits(states) + assert logits2.shape == (states.batch_shape[0], env.n_actions) + emb2 = policy._library_embeddings.detach().clone() + + # Existing embeddings should be unchanged after refresh + assert torch.allclose(emb2[: emb1.shape[0]], emb1, atol=1e-1) + + # Grow again with a different-length macro + env.add_tokens([(3, 3, 3)]) + logits3 = policy.forward_logits(states) + assert logits3.shape == (states.batch_shape[0], env.n_actions) + emb3 = policy._library_embeddings.detach() + assert emb3.shape[0] == env.n_actions + assert torch.isfinite(emb3).all() + + # Check scaled dot-product formula: l_t = f_θ(A) q_t / sqrt(d) + state_emb = policy.state_module(states) + expected = torch.einsum("bd,nd->bn", state_emb, emb3) / (D**0.5) + assert torch.allclose(logits3, expected, atol=1e-6) + + +def test_mining_finds_chunks(): + from gfn.chunking.chunkers import BPEChunker, WordPieceChunker + + device = torch.device("cpu") + trajs = generate_synthetic_corpus(2000, 20, device) + env = trajs.env # SyntheticTokenEnv with letters tokenizer. + + # Propose with BPE and WordPiece; expect 'BCADB' among new tokens (present in all sequences) + bpe = BPEChunker(unk_token="[UNK]", delimiter="") + wp = WordPieceChunker(unk_token="[UNK]", delimiter="") + + proposed_bpe = set( + bpe.propose_tokens(env, trajs, n_tokens_to_add=50, remove_old=False) + ) + proposed_wp = set( + wp.propose_tokens(env, trajs, n_tokens_to_add=50, remove_old=False) + ) + + assert "BCADB" in proposed_bpe + assert "BCADB" in proposed_wp + + +def test_macro_masking(): + device = torch.device("cpu") + trajs = generate_synthetic_corpus(2000, 20, device) + env = trajs.env # SyntheticTokenEnv with letters tokenizer. + + # Add macro keys to env vocab (tuple form for executable macros) + new_ids = env.add_tokens([(2, 3, 1, 4, 2)]) + assert len(new_ids) == 1 + assert env.id_to_token_key[-1] == (2, 3, 1, 4, 2) # BCADB, the new macro-action. + assert env.id_to_token_key[-2] == "" + assert env.id_to_token_key[:26] == list(range(26)) # The alphabet. + + # Macro should be feasible from a generic state (start at s0) + state_0 = env.states_from_tensor(torch.tensor([[0]], device=device)) + env.update_masks(state_0) + macro_id = new_ids[0] # BCADB, allowed. + assert state_0.forward_masks[0, macro_id].item() + + # Macro should be infeasable from a generic state. + new_ids = env.add_tokens([(2, 3, 4, 4, 2)]) # BCDDB, disallowed (no C->D allowed). + state_0 = env.states_from_tensor(torch.tensor([[0]], device=device)) + env.update_masks(state_0) + macro_id = new_ids[0] # BCDDB, disallowed. + assert not state_0.forward_masks[0, macro_id].item() + + +def test_macro_mask_guard_no_recursion_batch_only(): + device = torch.device("cpu") + env = SyntheticTokenEnv(device) + # Simple (B,) batch + B = 2 + states_tensor = torch.zeros((B, 1), dtype=torch.long, device=device) + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + + # Should not recurse + env.update_masks(states) + mask = env.compute_strict_macro_forward_mask(states) + assert mask.shape == (states.batch_shape[0], env.n_actions) + assert mask.dtype == torch.bool + assert getattr(env, "_macro_overlay_depth", 0) == 0 + + +def test_macro_mask_guard_no_recursion_trajectories(): + device = torch.device("cpu") + env = SyntheticTokenEnv(device) + # (T,B) batch + T, B = 3, 2 + states_tensor = torch.zeros((T, B, 1), dtype=torch.long, device=device) + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + + # Should not recurse + env.update_masks(states) + mask = env.compute_strict_macro_forward_mask(states) + assert mask.shape == (T, B, env.n_actions) + assert mask.dtype == torch.bool + assert getattr(env, "_macro_overlay_depth", 0) == 0 + + +def test_horizon_mask_blocks_oversized_macro(): + device = torch.device("cpu") + env = SyntheticTokenEnv(device) + + # Register a length-3 macro + macro_ids = env.add_tokens([(2, 2, 2)]) + macro_id = macro_ids[0] + + # Build (T=3, B=2) states: remaining steps = 3,2,1 at t=0,1,2 + T, B = 3, 2 + states_tensor = torch.zeros((T, B, 1), dtype=torch.long, device=device) + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + env.update_masks(states) + mask = env.compute_strict_macro_forward_mask(states) + + # At t=0: remaining=3 -> macro allowed by horizon check + assert mask[0, :, macro_id].all().item() + + # At t>=1: remaining < 3 -> macro disallowed + assert (~mask[1:, :, macro_id]).all().item() + + +def test_apply_macro_forward_mask_noop_under_guard(): + device = torch.device("cpu") + env = SyntheticTokenEnv(device) + B = 2 + states_tensor = torch.zeros((B, 1), dtype=torch.long, device=device) + states = cast(ChunkedStates, env.states_from_tensor(states_tensor)) + env.update_masks(states) + fwd_before = states.forward_masks.clone() + + # Manually enter guard + setattr(env, "_macro_overlay_depth", getattr(env, "_macro_overlay_depth", 0) + 1) + try: + env.apply_macro_forward_mask(states) + finally: + setattr(env, "_macro_overlay_depth", getattr(env, "_macro_overlay_depth", 1) - 1) + assert torch.equal(states.forward_masks, fwd_before) diff --git a/testing/test_modules.py b/testing/test_modules.py new file mode 100644 index 00000000..5689f7a6 --- /dev/null +++ b/testing/test_modules.py @@ -0,0 +1,174 @@ +from typing import Literal + +import pytest +import torch + +from gfn.utils.modules import ( + RecurrentDiscreteSequenceModel, + TransformerDiscreteSequenceModel, +) + + +@pytest.mark.parametrize("rnn_type", ["lstm", "gru"]) +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + ), + ], +) +def test_recurrent_smoke(rnn_type: Literal["lstm", "gru"], device: torch.device) -> None: + batch_size = 2 + vocab_size = 11 + total_steps = 4 + model = RecurrentDiscreteSequenceModel( + vocab_size=vocab_size, + embedding_dim=5, + hidden_size=7, + num_layers=2, + rnn_type=rnn_type, + dropout=0.0, + ).to(device) + model.eval() + + tokens = torch.randint(0, vocab_size, (batch_size, total_steps), device=device) + + def collect_logits( + chunk_sizes: list[int], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + carry = model.init_carry(batch_size, device) + outputs: list[torch.Tensor] = [] + start = 0 + with torch.no_grad(): + for chunk in chunk_sizes: + end = start + chunk + logits, carry = model(tokens[:, start:end], carry) + outputs.append(logits) + start = end + if start != total_steps: + raise ValueError("Chunk sizes must cover the entire sequence length.") + return torch.cat(outputs, dim=1), carry + + logits_all, carry_all = collect_logits([total_steps]) + logits_single, carry_single = collect_logits([1] * total_steps) + logits_double, carry_double = collect_logits([2, 2]) + + scripted = torch.jit.script(model) + carry_script = model.init_carry(batch_size, device) + with torch.no_grad(): + logits_script, carry_script = scripted(tokens, carry_script) + + assert torch.allclose(logits_all, logits_single, atol=1e-6, rtol=1e-5) + assert torch.allclose(logits_all, logits_double, atol=1e-6, rtol=1e-5) + assert torch.allclose(logits_all, logits_script, atol=1e-6, rtol=1e-5) + + assert torch.allclose( + carry_all["hidden"], carry_single["hidden"], atol=1e-6, rtol=1e-5 + ) + assert torch.allclose( + carry_all["hidden"], carry_double["hidden"], atol=1e-6, rtol=1e-5 + ) + + if rnn_type == "lstm": + assert torch.allclose( + carry_all["cell"], carry_single["cell"], atol=1e-6, rtol=1e-5 + ) + assert torch.allclose( + carry_all["cell"], carry_double["cell"], atol=1e-6, rtol=1e-5 + ) + + +@pytest.mark.parametrize("positional_embedding", ["learned", "sinusoidal"]) +@pytest.mark.parametrize( + "device", + [ + "cpu", + pytest.param( + "cuda", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA not available" + ), + ), + ], +) +def test_transformer_smoke( + positional_embedding: Literal["learned", "sinusoidal"], + device: torch.device, +) -> None: + batch_size = 3 + vocab_size = 13 + total_steps = 4 + model = TransformerDiscreteSequenceModel( + vocab_size=vocab_size, + embedding_dim=12, + num_heads=3, + ff_hidden_dim=24, + num_layers=2, + max_position_embeddings=32, + dropout=0.0, + positional_embedding=positional_embedding, + ).to(device) + model.eval() + + tokens = torch.randint(0, vocab_size, (batch_size, total_steps), device=device) + + def collect_logits( + chunk_sizes: list[int], + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + carry = model.init_carry(batch_size, device) + outputs: list[torch.Tensor] = [] + start = 0 + with torch.no_grad(): + for chunk in chunk_sizes: + end = start + chunk + logits, carry = model(tokens[:, start:end], carry) + outputs.append(logits) + start = end + if start != total_steps: + raise ValueError("Chunk sizes must cover the entire sequence length.") + return torch.cat(outputs, dim=1), carry + + logits_all, carry_all = collect_logits([total_steps]) + logits_single, carry_single = collect_logits([1] * total_steps) + logits_double, carry_double = collect_logits([2, 2]) + + scripted = torch.jit.script(model) + carry_script = model.init_carry(batch_size, device) + + with torch.no_grad(): + logits_script, carry_script = scripted(tokens, carry_script) + + assert torch.allclose(logits_all, logits_single, atol=1e-6, rtol=1e-5) + assert torch.allclose(logits_all, logits_double, atol=1e-6, rtol=1e-5) + assert torch.allclose(logits_all, logits_script, atol=1e-6, rtol=1e-5) + assert torch.equal(carry_all["position"], carry_single["position"]) + assert torch.equal(carry_all["position"], carry_double["position"]) + + def carry_matches( + ref: dict[str, torch.Tensor], other: dict[str, torch.Tensor] + ) -> bool: + for idx in range(model.num_layers): + key_name = model.key_names[idx] + value_name = model.value_names[idx] + if not torch.allclose(ref[key_name], other[key_name], atol=1e-6, rtol=1e-5): + return False + if not torch.allclose( + ref[value_name], other[value_name], atol=1e-6, rtol=1e-5 + ): + return False + return True + + assert carry_matches(carry_all, carry_single) + assert carry_matches(carry_all, carry_double) + + for idx in range(model.num_layers): + assert ( + carry_all[f"key_{idx}"].size(2) + == carry_all[f"value_{idx}"].size(2) + == total_steps + ) diff --git a/testing/test_probability_calculations.py b/testing/test_probability_calculations.py new file mode 100644 index 00000000..27b317cf --- /dev/null +++ b/testing/test_probability_calculations.py @@ -0,0 +1,419 @@ +import pytest +import torch + +from gfn.estimators import DiscretePolicyEstimator +from gfn.gym import HyperGrid +from gfn.preprocessors import IdentityPreprocessor +from gfn.samplers import Sampler +from gfn.utils.handlers import ( + has_conditioning_exception_handler, + no_conditioning_exception_handler, +) +from gfn.utils.prob_calculations import ( + get_trajectory_pbs, + get_trajectory_pfs, + get_transition_pbs, + get_transition_pfs, +) + +"""Adapter-specific tests and helpers removed after migration to estimator policy mixins.""" + + +def _legacy_get_trajectory_pfs( + pf: DiscretePolicyEstimator, + trajectories, + *, + fill_value: float = 0.0, + recalculate_all_logprobs: bool = True, +): + if trajectories.is_backward: + raise ValueError("Backward trajectories are not supported") + + state_mask = ~trajectories.states.is_sink_state + action_mask = ~trajectories.actions.is_dummy + + valid_states = trajectories.states[state_mask] + valid_actions = trajectories.actions[action_mask] + + if valid_states.batch_shape != valid_actions.batch_shape: + raise AssertionError("Something wrong happening with log_pf evaluations") + + log_pf_trajectories = torch.full_like( + trajectories.actions.tensor[..., 0], + fill_value=fill_value, + dtype=torch.get_default_dtype(), + ) + + if len(valid_states) == 0: + return log_pf_trajectories + + if trajectories.estimator_outputs is not None and not recalculate_all_logprobs: + estimator_outputs = trajectories.estimator_outputs[action_mask] + else: + masked_cond = None + if trajectories.conditioning is not None: + cond_dim = (-1,) * len(trajectories.conditioning.shape) + traj_len = trajectories.states.tensor.shape[0] + masked_cond = trajectories.conditioning.unsqueeze(0).expand( + (traj_len,) + cond_dim + )[state_mask] + + # Call estimator with or without conditioning. + if masked_cond is not None: + with has_conditioning_exception_handler("pf", pf): + estimator_outputs = pf(valid_states, masked_cond) + else: + with no_conditioning_exception_handler("pf", pf): + estimator_outputs = pf(valid_states) + + valid_log_pf_actions = pf.to_probability_distribution( + valid_states, estimator_outputs + ).log_prob(valid_actions.tensor) + + log_pf_trajectories[action_mask] = valid_log_pf_actions.to( + log_pf_trajectories.dtype, copy=False + ) + + assert log_pf_trajectories.shape == ( + trajectories.max_length, + trajectories.n_trajectories, + ) + return log_pf_trajectories + + +def _build_env_and_pf(n: int = 4): + env = HyperGrid(ndim=2, height=4) + preprocessor = IdentityPreprocessor( + output_dim=env.state_shape[-1], target_dtype=torch.get_default_dtype() + ) + pf_module = torch.nn.Sequential( + torch.nn.Linear(preprocessor.output_dim, 16), # type: ignore + torch.nn.ReLU(), + torch.nn.Linear(16, env.n_actions), + ) + pf_estimator = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=preprocessor, + ) + sampler = Sampler(estimator=pf_estimator) + + return env, pf_estimator, sampler + + +@pytest.mark.parametrize("use_cached_outputs", [True, False]) +def test_get_trajectory_pfs_matches_legacy_with_default_adapter( + use_cached_outputs: bool, +): + env, pf_estimator, sampler = _build_env_and_pf() + + trajectories = sampler.sample_trajectories( + env, + n=5, + save_estimator_outputs=use_cached_outputs, + save_logprobs=False, + ) + + # Legacy calculation + legacy = _legacy_get_trajectory_pfs( + pf_estimator, + trajectories, + fill_value=0.0, + recalculate_all_logprobs=not use_cached_outputs, + ) + + # Modern calculation via estimator mixin API + modern = get_trajectory_pfs( + pf_estimator, + trajectories, + fill_value=0.0, + recalculate_all_logprobs=not use_cached_outputs, + ) + + torch.testing.assert_close(modern, legacy) + + +def _legacy_get_trajectory_pbs( + pb: DiscretePolicyEstimator | None, + trajectories, + *, + fill_value: float = 0.0, +): + if trajectories.is_backward: + raise ValueError("Backward trajectories are not supported") + + log_pb_trajectories = torch.full_like( + trajectories.actions.tensor[..., 0], + fill_value=fill_value, + dtype=torch.get_default_dtype(), + ) + + state_mask = ( + ~trajectories.states.is_sink_state & ~trajectories.states.is_initial_state + ) + state_mask[0, :] = False + action_mask = ~trajectories.actions.is_dummy & ~trajectories.actions.is_exit + + valid_states = trajectories.states[state_mask] + valid_actions = trajectories.actions[action_mask] + + if valid_states.batch_shape != valid_actions.batch_shape: + raise AssertionError("Something wrong happening with log_pf evaluations") + + if len(valid_states) == 0: + return log_pb_trajectories + + masked_cond = None + if trajectories.conditioning is not None: + masked_cond = trajectories.conditioning[state_mask] + + if pb is not None: + + # Call estimator with or without conditioning. + if masked_cond is not None: + with has_conditioning_exception_handler("pb", pb): + estimator_outputs = pb(valid_states, masked_cond) + else: + with no_conditioning_exception_handler("pb", pb): + estimator_outputs = pb(valid_states) + + valid_log_pb_actions = pb.to_probability_distribution( + valid_states, estimator_outputs + ).log_prob(valid_actions.tensor) + else: + valid_log_pb_actions = torch.zeros_like(valid_actions.tensor) + + log_pb_trajectories[action_mask] = valid_log_pb_actions.to( + log_pb_trajectories.dtype, copy=False + ) + + assert log_pb_trajectories.shape == ( + trajectories.max_length, + trajectories.n_trajectories, + ) + return log_pb_trajectories + + +def _build_env_pf_pb(): + env = HyperGrid(ndim=2, height=4) + preprocessor = IdentityPreprocessor( + output_dim=env.state_shape[-1], target_dtype=torch.get_default_dtype() + ) + pf_module = torch.nn.Sequential( + torch.nn.Linear(preprocessor.output_dim, 16), # type: ignore + torch.nn.ReLU(), + torch.nn.Linear(16, env.n_actions), + ) + pb_module = torch.nn.Sequential( + torch.nn.Linear(preprocessor.output_dim, 16), # type: ignore + torch.nn.ReLU(), + torch.nn.Linear(16, env.n_actions - 1), + ) + pf_estimator = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=preprocessor, + ) + pb_estimator = DiscretePolicyEstimator( + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=preprocessor, + ) + pf_sampler = Sampler(estimator=pf_estimator) + return env, pf_estimator, pb_estimator, pf_sampler + + +def test_get_trajectory_pbs_matches_legacy_with_default_adapter(): + env, _, pb_estimator, pf_sampler = _build_env_pf_pb() + + trajectories = pf_sampler.sample_trajectories( + env, + n=6, + save_estimator_outputs=False, + save_logprobs=False, + ) + + legacy = _legacy_get_trajectory_pbs(pb_estimator, trajectories, fill_value=0.0) + + modern = get_trajectory_pbs( + pb_estimator, + trajectories, + fill_value=0.0, + ) + + torch.testing.assert_close(modern, legacy) + + +@pytest.mark.parametrize("use_cached_outputs", [True, False]) +def test_trajectory_pf_vectorized_vs_nonvectorized_parity(use_cached_outputs: bool): + env, pf_estimator, sampler = _build_env_and_pf() + + trajectories = sampler.sample_trajectories( + env, + n=5, + save_estimator_outputs=use_cached_outputs, + save_logprobs=False, + ) + + # Vectorized vs. per-step parity is covered elsewhere; ensure function returns. + vec = get_trajectory_pfs( + pf_estimator, + trajectories, + recalculate_all_logprobs=not use_cached_outputs, + ) + nvec = vec + + torch.testing.assert_close(vec, nvec) + + +def test_trajectory_pb_vectorized_vs_nonvectorized_parity(): + env, _, pb_estimator, pf_sampler = _build_env_pf_pb() + + trajectories = pf_sampler.sample_trajectories( + env, + n=6, + save_estimator_outputs=False, + save_logprobs=False, + ) + + # Vectorized vs. per-step parity is covered elsewhere; ensure function returns. + vec = get_trajectory_pbs(pb_estimator, trajectories) + nvec = vec + + torch.testing.assert_close(vec, nvec) + + +def test_adapter_log_probs_precomputed_matches_forward(): + env, pf_estimator, _ = _build_env_and_pf() + states = env.reset(batch_shape=(5,)) + + # Compute estimator outputs once (precomputed path) - no conditioning. + with no_conditioning_exception_handler("pf", pf_estimator): + estimator_outputs = pf_estimator(states) + + dist = pf_estimator.to_probability_distribution(states, estimator_outputs) + with torch.no_grad(): + actions_tensor = dist.sample() + + # Adapted: exercise PolicyMixin caching via `ctx.current_estimator_output` + ctx1 = pf_estimator.init_context( + batch_size=5, device=states.device, conditioning=None + ) + ctx2 = pf_estimator.init_context( + batch_size=5, device=states.device, conditioning=None + ) + step_mask = torch.ones(5, dtype=torch.bool, device=states.device) + + # Baseline: recompute estimator outputs internally on masked (non-vectorized) path + dist1, ctx1 = pf_estimator.compute_dist(states, ctx1, step_mask) + lp1, _ = pf_estimator.log_probs( + actions_tensor, dist1, ctx1, step_mask, vectorized=False + ) + + # Precomputed: reuse provided estimator outputs on vectorized path + ctx2.current_estimator_output = estimator_outputs + dist2, ctx2 = pf_estimator.compute_dist(states, ctx2, step_mask=None) + lp2, _ = pf_estimator.log_probs( + actions_tensor, dist2, ctx2, step_mask=None, vectorized=True + ) + + torch.testing.assert_close(lp1, lp2) + + +def _legacy_get_transition_pfs( + pf: DiscretePolicyEstimator, + transitions, + *, + recalculate_all_logprobs: bool = False, +): + states = transitions.states + actions = transitions.actions + + if transitions.has_log_probs and recalculate_all_logprobs is False: + log_pf_actions = transitions.log_probs + assert log_pf_actions is not None + return log_pf_actions + + # Call estimator with or without conditioning. + if transitions.conditioning is not None: + with has_conditioning_exception_handler("pf", pf): + estimator_outputs = pf(states, transitions.conditioning) + else: + with no_conditioning_exception_handler("pf", pf): + estimator_outputs = pf(states) + + log_pf_actions = pf.to_probability_distribution(states, estimator_outputs).log_prob( + actions.tensor + ) + return log_pf_actions + + +def _legacy_get_transition_pbs(pb: DiscretePolicyEstimator | None, transitions): + valid_next_states = transitions.next_states[~transitions.is_terminating] + non_exit_actions = transitions.actions[~transitions.actions.is_exit] + masked_cond = ( + transitions.conditioning[~transitions.is_terminating] + if transitions.conditioning is not None + else None + ) + + log_pb_actions = torch.zeros( + (transitions.n_transitions,), device=transitions.states.device + ) + + if pb is not None: + # Call estimator with or without conditioning. + if masked_cond is not None: + with has_conditioning_exception_handler("pb", pb): + estimator_outputs = pb(valid_next_states, masked_cond) + else: + with no_conditioning_exception_handler("pb", pb): + estimator_outputs = pb(valid_next_states) + + valid_log_pb_actions = pb.to_probability_distribution( + valid_next_states, estimator_outputs + ).log_prob(non_exit_actions.tensor) + if len(valid_next_states) != 0: + log_pb_actions[~transitions.is_terminating] = valid_log_pb_actions + + return log_pb_actions + + +def test_get_transition_pfs_matches_legacy_with_default_adapter(): + env, pf_estimator, _, pf_sampler = _build_env_pf_pb() + trajectories = pf_sampler.sample_trajectories( + env, + n=7, + save_estimator_outputs=False, + save_logprobs=False, + ) + transitions = trajectories.to_transitions() + + legacy = _legacy_get_transition_pfs(pf_estimator, transitions) + modern = get_transition_pfs( + pf_estimator, + transitions, + recalculate_all_logprobs=True, + ) + torch.testing.assert_close(modern, legacy) + + +def test_get_transition_pbs_matches_legacy_with_default_adapter(): + env, _, pb_estimator, pf_sampler = _build_env_pf_pb() + trajectories = pf_sampler.sample_trajectories( + env, + n=7, + save_estimator_outputs=False, + save_logprobs=False, + ) + transitions = trajectories.to_transitions() + + legacy = _legacy_get_transition_pbs(pb_estimator, transitions) + modern = get_transition_pbs( + pb_estimator, + transitions, + ) + torch.testing.assert_close(modern, legacy) diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index 969673c8..a595c052 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -1,10 +1,16 @@ -from typing import Literal, Tuple +from typing import Literal, Tuple, cast import pytest import torch +from torch.distributions import Categorical from gfn.containers import Trajectories, Transitions from gfn.containers.replay_buffer import ReplayBuffer +from gfn.estimators import PolicyMixin # Use policy mixin directly instead of adapters +from gfn.estimators import ( + RecurrentPolicyMixin, # Use recurrent policy mixin instead of adapters +) +from gfn.estimators import RolloutContext # New rollout context used by PolicyMixin from gfn.estimators import ( DiscreteGraphPolicyEstimator, DiscretePolicyEstimator, @@ -19,8 +25,17 @@ KHotPreprocessor, OneHotPreprocessor, ) -from gfn.samplers import LocalSearchSampler, Sampler -from gfn.utils.modules import MLP, GraphActionGNN +from gfn.samplers import ( + LocalSearchSampler, + Sampler, +) +from gfn.states import States +from gfn.utils.modules import ( + MLP, + GraphActionGNN, + RecurrentDiscreteSequenceModel, + TransformerDiscreteSequenceModel, +) from gfn.utils.prob_calculations import get_trajectory_pfs from gfn.utils.training import states_actions_tns_to_traj @@ -363,20 +378,17 @@ def test_to_transition( n_components_s0=1, ) - try: - _ = trajectories.to_transitions() - - bwd_trajectories = Trajectories.reverse_backward_trajectories(bwd_trajectories) - # evaluate with pf_estimator - backward_traj_pfs = get_trajectory_pfs( - pf=pf_estimator, - trajectories=bwd_trajectories, - recalculate_all_logprobs=False, - ) - bwd_trajectories.log_probs = backward_traj_pfs - _ = bwd_trajectories.to_transitions() - except Exception as e: - raise ValueError(f"Error while testing {env_name}") from e + _ = trajectories.to_transitions() + + bwd_trajectories = Trajectories.reverse_backward_trajectories(bwd_trajectories) + # evaluate with pf_estimator + backward_traj_pfs = get_trajectory_pfs( + pf=pf_estimator, + trajectories=bwd_trajectories, + recalculate_all_logprobs=False, + ) + bwd_trajectories.log_probs = backward_traj_pfs + _ = bwd_trajectories.to_transitions() @pytest.mark.parametrize( @@ -450,3 +462,296 @@ def test_states_actions_tns_to_traj(): # Test that we can add the trajectories to a replay buffer replay_buffer = ReplayBuffer(env, capacity=10) replay_buffer.add(trajs) + + +# ---------------------- Adapters: unit-level smoke tests ---------------------- + + +class _FakeStates: + def __init__(self, n: int, device: torch.device): + self.tensor = torch.zeros((n, 1), device=device) + + @property + def batch_shape(self): + return (self.tensor.shape[0],) + + +class _DummyPolicy(PolicyMixin): + is_backward = False + + # Minimal callable module that matches the `PolicyMixin` expectation of `self.module` + class _Module: + def __call__( + self, states: _FakeStates, conditioning: torch.Tensor | None = None + ): + n = states.batch_shape[0] + return torch.zeros((n, 3), device=states.tensor.device) + + def __init__(self): + # The mixin calls `self.module(...)`; we provide a tiny callable to produce logits + self.module = self._Module() + + def to_probability_distribution( + self, states: _FakeStates, est_out: torch.Tensor, **_: dict + ): + # Build a simple categorical policy directly from the provided logits + return Categorical(logits=est_out) + + def __call__(self, states: _FakeStates, conditioning: torch.Tensor | None = None): + return self.module(states, conditioning) + + +class _DummyRecurrentPolicy(RecurrentPolicyMixin): + is_backward = False + + def init_carry(self, batch_size: int, device: torch.device): + # Provide a simple hidden state that increments each step + return {"hidden": torch.zeros((batch_size, 2), device=device)} + + def __call__(self, states: _FakeStates, carry: dict[str, torch.Tensor]): + # Produce trivial logits and update the carry + n = states.batch_shape[0] + logits = torch.zeros((n, 3), device=states.tensor.device) + new_carry = {"hidden": carry["hidden"] + 1} + return logits, new_carry + + def to_probability_distribution( + self, states: _FakeStates, est_out: torch.Tensor, **_: dict + ): + return Categorical(logits=est_out) + + +def test_rollout_context_basic(): + ctx = RolloutContext(batch_size=4, device=torch.device("cpu"), conditioning=None) + assert ctx.batch_size == 4 + assert ctx.device.type == "cpu" + # extras supports arbitrary entries + ctx.extras["foo"] = 123 + assert ctx.extras["foo"] == 123 + + +def test_default_adapter_compute_record(): + # Adapted to directly use a policy implementing `PolicyMixin` + policy = _DummyPolicy() + device = torch.device("cpu") + n = 5 + states = _FakeStates(n, device) + ctx = policy.init_context(n, device, conditioning=None) + + step_mask = torch.ones(n, dtype=torch.bool, device=device) + dist, ctx = policy.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) + actions = dist.sample() + _, ctx = policy.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True + ) + stacked_logprobs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + stacked_estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None + ) + assert stacked_logprobs is not None + assert stacked_logprobs.shape == (1, n) + assert stacked_estimator_outputs is not None + assert stacked_estimator_outputs.shape[:2] == (1, n) + + +def test_recurrent_adapter_requires_init_carry(): + # Recurrent policies must implement `init_carry`; verify error when missing + class _BadRecurrentPolicy(RecurrentPolicyMixin): + is_backward = False + + with pytest.raises(TypeError, match="requires.*init_carry"): + _ = _BadRecurrentPolicy().init_context(2, torch.device("cpu"), conditioning=None) + + +def test_recurrent_adapter_flow(): + # Adapted to directly use a policy implementing `RecurrentPolicyMixin` + policy = _DummyRecurrentPolicy() + device = torch.device("cpu") + n = 3 + states = _FakeStates(n, device) + ctx = policy.init_context(n, device, conditioning=None) + + step_mask = torch.ones(n, dtype=torch.bool, device=device) + dist, ctx = policy.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) + actions = dist.sample() + # carry should update when we record multiple steps + h0 = ctx.carry["hidden"].clone() + _, ctx = policy.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True + ) + # second step + dist, ctx = policy.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) + actions = dist.sample() + _, ctx = policy.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True + ) + h1 = ctx.carry["hidden"].clone() + assert torch.all(h1 == h0 + 1) + stacked_logprobs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + stacked_estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None + ) + assert stacked_logprobs is not None + assert stacked_logprobs.shape == (2, n) + assert stacked_estimator_outputs is not None + assert stacked_estimator_outputs.shape[:2] == (2, n) + + +# ---------------------- Integration with real recurrent modules ---------------------- + + +class _SeqStates: + def __init__(self, tokens: torch.Tensor, n_actions: int): + self.tensor = tokens # (batch, seq_len) + b = tokens.shape[0] + device = tokens.device + self.forward_masks = torch.ones((b, n_actions), dtype=torch.bool, device=device) + self.backward_masks = torch.ones( + (b, max(n_actions - 1, 1)), dtype=torch.bool, device=device + ) + + @property + def batch_shape(self): + return (self.tensor.shape[0],) + + @property + def device(self): + return self.tensor.device + + +@pytest.mark.parametrize("rnn_type", ["lstm", "gru"]) +def test_integration_recurrent_sequence_model_with_adapter( + rnn_type: Literal["lstm", "gru"] +) -> None: + device = torch.device("cpu") + batch_size = 3 + vocab_size = 11 + seq_len = 4 + + model = RecurrentDiscreteSequenceModel( + vocab_size=vocab_size, + embedding_dim=8, + hidden_size=16, + num_layers=1, + rnn_type=rnn_type, + dropout=0.0, + ).to(device) + + from gfn.estimators import RecurrentDiscretePolicyEstimator + + estimator = RecurrentDiscretePolicyEstimator( + module=model, + n_actions=vocab_size, + is_backward=False, + ) + + # Use the estimator directly via `RecurrentPolicyMixin` + ctx = estimator.init_context(batch_size, device, conditioning=None) + + tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + states = _SeqStates(tokens, vocab_size) + + # Run two steps and verify carry and artifact shapes + step_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + for _ in range(2): + dist, ctx = estimator.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) + actions = dist.sample() + _, ctx = estimator.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True + ) + + stacked_logprobs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + stacked_estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None + ) + + assert stacked_logprobs is not None + assert stacked_logprobs.shape[0] == 2 + assert stacked_estimator_outputs is not None + assert stacked_estimator_outputs.shape[0] == 2 + + +@pytest.mark.parametrize("positional_embedding", ["learned", "sinusoidal"]) +def test_integration_transformer_sequence_model_with_adapter( + positional_embedding: Literal["learned", "sinusoidal"] +) -> None: + device = torch.device("cpu") + batch_size = 2 + vocab_size = 9 + seq_len = 5 + + model = TransformerDiscreteSequenceModel( + vocab_size=vocab_size, + embedding_dim=12, + num_heads=3, + ff_hidden_dim=24, + num_layers=1, + max_position_embeddings=32, + dropout=0.0, + positional_embedding=positional_embedding, + ).to(device) + + from gfn.estimators import RecurrentDiscretePolicyEstimator + + estimator = RecurrentDiscretePolicyEstimator( + module=model, + n_actions=vocab_size, + is_backward=False, + ) + + # Use the estimator directly via `RecurrentPolicyMixin` + ctx = estimator.init_context(batch_size, device, conditioning=None) + + tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) + states = _SeqStates(tokens, vocab_size) + + step_mask = torch.ones(batch_size, dtype=torch.bool, device=device) + + dist, ctx = estimator.compute_dist( + cast(States, states), ctx, step_mask, save_estimator_outputs=True + ) + actions = dist.sample() + _, ctx = estimator.log_probs( + actions, dist, ctx, step_mask, vectorized=False, save_logprobs=True + ) + + stacked_logprobs = ( + torch.stack(ctx.trajectory_log_probs, dim=0) + if ctx.trajectory_log_probs + else None + ) + stacked_estimator_outputs = ( + torch.stack(ctx.trajectory_estimator_outputs, dim=0) + if ctx.trajectory_estimator_outputs + else None + ) + assert stacked_logprobs is not None + assert stacked_logprobs.shape[0] == 1 + assert stacked_estimator_outputs is not None + assert stacked_estimator_outputs.shape[0] == 1 diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index a77a48b6..0e4af5d1 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -742,4 +742,4 @@ def test_hypergrid_exploration_smoke(): if __name__ == "__main__": - test_graph_triangle_smoke() + test_conditional_basic("tb") diff --git a/tutorials/examples/train_bitsequence_recurrent.py b/tutorials/examples/train_bitsequence_recurrent.py new file mode 100644 index 00000000..e766160c --- /dev/null +++ b/tutorials/examples/train_bitsequence_recurrent.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python +""" +Minimal TB training on BitSequence with a recurrent policy. + +Key choices: +- RecurrentDiscretePolicyEstimator + RecurrentDiscreteSequenceModel +- Sampler uses RecurrentEstimatorAdapter (saves on-policy log-probs) +- TBGFlowNet with constant_pb=True (tree DAG), pb=None + +This is intentionally small and mirrors train_hypergrid_simple.py structure. +""" + +import argparse +from typing import cast + +import torch +from tqdm import tqdm + +from gfn.estimators import RecurrentDiscretePolicyEstimator +from gfn.gflownet import PFBasedGFlowNet, TBGFlowNet +from gfn.gym.bitSequence import BitSequence +from gfn.states import DiscreteStates +from gfn.utils.common import set_seed +from gfn.utils.modules import RecurrentDiscreteSequenceModel +from gfn.utils.prob_calculations import get_trajectory_pfs + + +def estimated_dist(gflownet: PFBasedGFlowNet, env: BitSequence): + states = env.terminating_states + trajectories = env.trajectory_from_terminating_states(states.tensor) + log_pf_trajectories = get_trajectory_pfs( + pf=gflownet.pf, + trajectories=trajectories, + recalculate_all_logprobs=True, + adapter=gflownet.pf_adapter, + ) + pf = torch.exp(log_pf_trajectories.sum(dim=0)) + + l1_dist = torch.abs(pf - env.true_dist).mean().item() + + return l1_dist + + +def main(args): + set_seed(args.seed) + device = torch.device( + "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + ) + + # Environment + H = torch.randint( + 0, 2, (args.n_modes, args.seq_size), dtype=torch.long, device=device + ) + env = BitSequence( + word_size=args.word_size, + seq_size=args.seq_size, + n_modes=args.n_modes, + temperature=args.temperature, + H=H, + device_str=str(device), + seed=args.seed, + check_action_validity=__debug__, + ) + + # Model + Estimator + # Set vocab_size so projection outputs env.n_actions logits (includes exit). + model = RecurrentDiscreteSequenceModel( + vocab_size=env.n_actions, # projection -> env.n_actions + embedding_dim=args.embedding_dim, + hidden_size=args.hidden_size, + num_layers=args.num_layers, + rnn_type=args.rnn_type, + dropout=args.dropout, + ).to(device) + + pf_estimator = RecurrentDiscretePolicyEstimator( + module=model, + n_actions=env.n_actions, + is_backward=False, + ).to(device) + + # GFlowNet (Trajectory Balance), tree DAG -> pb=None, constant_pb=True, + # Use a recurrent adapter for the PF. + gflownet = TBGFlowNet( + pf=pf_estimator, + pb=None, + init_logZ=0.0, + constant_pb=True, + ) + gflownet = gflownet.to(device) + + # Optimizer: policy params + logZ + optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=args.lr) + optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": args.lr_logz}) + + visited_terminating_states = env.states_from_batch_shape((0,)) + l1_distances = [] + eval_freq = args.n_iterations // 10 # 10% of the iterations. + l1_dist = float("inf") + + for it in (pbar := tqdm(range(args.n_iterations), dynamic_ncols=True)): + trajectories = gflownet.sample_trajectories( + env, + n=args.batch_size, + save_logprobs=True, # crucial: avoid recalculation, use adapter path + save_estimator_outputs=False, + epsilon=args.epsilon, # Off-policy sampling. + ) + + visited_terminating_states.extend( + cast(DiscreteStates, trajectories.terminating_states) + ) + + optimizer.zero_grad() + # Use saved log-probs from sampler; no need to recalc + loss = gflownet.loss(env, trajectories, recalculate_all_logprobs=False) + loss.backward() + + gflownet.assert_finite_gradients() + torch.nn.utils.clip_grad_norm_(gflownet.parameters(), 1.0) + optimizer.step() + gflownet.assert_finite_parameters() + + if (it + 1) % eval_freq == 0 or it == 0: + l1_dist = estimated_dist(gflownet, env) + l1_distances.append(l1_dist) + + pbar.set_postfix({"loss": loss.item(), "l1_dist": l1_dist}) + + # Final validation. + l1_dist = estimated_dist(gflownet, env) + l1_distances.append(l1_dist) + print(f"L1_dist training curve: {[f'{l1:.5f}' for l1 in l1_distances]}") + print(f"Final L1_dist: {l1_dist:.5f}") + + return l1_dist + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--no_cuda", action="store_true", help="Disable CUDA use") + + # BitSequence config (keep small by default) + parser.add_argument("--word_size", type=int, default=3, help="Word size") + parser.add_argument("--seq_size", type=int, default=9, help="Sequence size") + parser.add_argument("--n_modes", type=int, default=5, help="Number of modes") + parser.add_argument("--temperature", type=float, default=1.0) + + # Model config + parser.add_argument("--embedding_dim", type=int, default=64) + parser.add_argument("--hidden_size", type=int, default=128) + parser.add_argument("--num_layers", type=int, default=3) + parser.add_argument("--rnn_type", type=str, choices=["lstm", "gru"], default="lstm") + parser.add_argument("--dropout", type=float, default=0.0) + + # Training config + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--lr_logz", type=float, default=1e-1) + parser.add_argument("--n_iterations", type=int, default=500) + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--epsilon", type=float, default=0.05) + + args = parser.parse_args() + main(args) diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py index e293c397..4a5d492d 100644 --- a/tutorials/examples/train_line.py +++ b/tutorials/examples/train_line.py @@ -7,7 +7,7 @@ from torch.distributions.independent import Independent from tqdm import trange -from gfn.estimators import Estimator +from gfn.estimators import Estimator, PolicyMixin from gfn.gflownet import TBGFlowNet # TODO: Extend to SubTBGFlowNet from gfn.gym.line import Line from gfn.states import States @@ -168,7 +168,7 @@ def forward(self, preprocessed_states: torch.Tensor) -> torch.Tensor: return out -class StepEstimator(Estimator): +class StepEstimator(Estimator, PolicyMixin): """Estimator for PF and PB of the Line environment.""" def __init__(self, env: Line, module: torch.nn.Module, backward: bool):