From c96cb1acffe1d0ce3ccffdbcb601c274d30503db Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 7 Jul 2025 21:00:38 -0700 Subject: [PATCH 01/11] infilling --- tutorials/examples/test_scripts.py | 20 ++++ tutorials/examples/train_rng_gfn.py | 155 ++++++++++++++++++++++++++++ 2 files changed, 175 insertions(+) create mode 100644 tutorials/examples/train_rng_gfn.py diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 01e933c7..f8f6e4ba 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -20,6 +20,7 @@ from .train_hypergrid_simple_ls import main as train_hypergrid_simple_ls_main from .train_ising import main as train_ising_main from .train_line import main as train_line_main +from .train_rng_gfn import main as train_rng_gfn_main @dataclass @@ -170,6 +171,25 @@ class BayesianStructureArgs(CommonArgs): use_cuda: bool = False +@dataclass +class RNGGFNArgs(CommonArgs): + batch_size: int = 8 + n_trajectories: int = 8 + n_iterations: int = 10 + lr: float = 1e-4 + max_length: int = 5 + prompt: str = "The following is a random integer drawn uniformly between 0 and 100: " + + +@pytest.mark.parametrize("n_iterations", [10]) +def test_rng_gfn_smoke(n_iterations: int): + """Smoke test for the RNG GFN training script.""" + args = RNGGFNArgs(n_iterations=n_iterations) + args_dict = asdict(args) + namespace_args = Namespace(**args_dict) + train_rng_gfn_main(namespace_args) # Just ensure it runs without errors. + + @pytest.mark.parametrize("ndim", [2, 4]) @pytest.mark.parametrize("height", [8, 16]) @pytest.mark.parametrize("replay_buffer_size", [0, 1000]) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py new file mode 100644 index 00000000..db4e40de --- /dev/null +++ b/tutorials/examples/train_rng_gfn.py @@ -0,0 +1,155 @@ + +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoTokenizer +import numpy as np + +from gfn.env import Env +from gfn.states import DiscreteStates +from gfn.modules import GFNModule +from gfn.gflownet import GFlowNet, SubTrajectoryBalance +from gfn.samplers import Sampler + + + +# 1. 
Environment Definition +class RNGEnv(Env): + def __init__(self, tokenizer, prompt, max_length=5, device='cuda'): + self.tokenizer = tokenizer + self.prompt = prompt + self.prompt_tokens = tokenizer.encode(prompt, return_tensors='pt').to(device) + self.max_length = max_length + self.device = device + + s0 = DiscreteStates(self.prompt_tokens.squeeze(0)) + sf = DiscreteStates(torch.tensor([self.tokenizer.eos_token_id], device=self.device)) + + super().__init__(s0=s0, sf=sf, n_actions=tokenizer.vocab_size) + + def get_actions_masks(self, states: DiscreteStates) -> torch.Tensor: + masks = torch.ones(len(states), self.n_actions, dtype=torch.bool, device=self.device) + for i in range(len(states)): + if states.masks[i].sum() >= self.max_length + self.prompt_tokens.shape[1]: + masks[i, :] = False + masks[i, self.tokenizer.eos_token_id] = True + return masks + + def forward_step(self, states: DiscreteStates, actions: torch.Tensor) -> DiscreteStates: + new_states_list = [] + for i in range(len(states)): + state_tensor = states.tensor[i][states.masks[i]] + action = actions[i] + new_state_tensor = torch.cat([state_tensor, action.unsqueeze(0)]) + new_states_list.append(new_state_tensor) + return DiscreteStates(new_states_list) + + def log_reward(self, states: DiscreteStates) -> torch.Tensor: + rewards = torch.full((len(states),), -20.0, device=self.device) + + prompt_len = self.prompt_tokens.shape[1] + + sequences_to_decode = [] + valid_indices = [] + for i in range(len(states)): + if states.masks[i].sum() > prompt_len: + seq = states.tensor[i][states.masks[i]] + # Only reward terminal states + if seq[-1] == self.tokenizer.eos_token_id: + sequences_to_decode.append(seq[prompt_len:-1]) + valid_indices.append(i) + + if not sequences_to_decode: + return rewards + + decoded_texts = self.tokenizer.batch_decode(sequences_to_decode, skip_special_tokens=True) + + for i, decoded_text in enumerate(decoded_texts): + original_index = valid_indices[i] + try: + number = int(decoded_text.strip()) + if 0 <= number <= 100: + rewards[original_index] = torch.log(torch.tensor(1.0 / 101.0, device=self.device)) + except (ValueError, IndexError): + pass + + return rewards + + def reset(self, batch_size: int = 1) -> DiscreteStates: + return DiscreteStates([self.s0.tensor for _ in range(batch_size)]) + +# 2. GFN Module (wrapping the LLM) +class LLMGFNModule(GFNModule): + def __init__(self, model, tokenizer): + super().__init__() + self.model = model + self.tokenizer = tokenizer + self.log_z = nn.Parameter(torch.tensor(0.0)) + + def forward(self, states: States): + input_ids = states.tensor + attention_mask = states.masks + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) + + sequence_lengths = attention_mask.sum(dim=1) + last_token_logits = outputs.logits[torch.arange(len(outputs.logits)), sequence_lengths - 1, :] + + return last_token_logits, self.log_z + +# 3. 
Training Setup +def main(): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + print(f"Using device: {device}") + + tokenizer = AutoTokenizer.from_pretrained('gpt2') + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained('gpt2').to(device) + + prompt = "The following is a random integer drawn uniformly between 0 and 100: " + env = RNGEnv(tokenizer, prompt, device=device, max_length=5) + gfn_module = LLMGFNModule(model, tokenizer) + gflownet = GFlowNet(gfn_module, loss=SubTrajectoryBalance()) + sampler = Sampler(gfn_module, env) + + optimizer = torch.optim.Adam(gfn_module.parameters(), lr=1e-4) + + print("Starting training...") + # 4. Training Loop + for i in range(501): # A short training loop for demonstration + trajectories = sampler.sample_trajectories(n_trajectories=8) + loss = gflownet.loss(trajectories) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if i % 100 == 0: + print(f"Step {i}, Loss: {loss.item()}") + + # 5. Evaluation + print("\n--- Evaluation ---") + with torch.no_grad(): + trajectories = sampler.sample_trajectories(n_trajectories=100) + final_states = trajectories.last_states + + numbers = [] + for i in range(len(final_states)): + state_tensor = final_states.tensor[i][final_states.masks[i]] + generated_tokens = state_tensor[len(env.prompt_tokens[0]):-1] # Exclude prompt and EOS + decoded_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) + try: + numbers.append(int(decoded_text.strip())) + except (ValueError, IndexError): + pass + + print(f"Generated numbers: {numbers}") + if numbers: + counts = np.bincount(numbers, minlength=101) + print(f"Counts of numbers 0-100: {counts}") + print(f"Number of valid samples: {len(numbers)}") + else: + print("No valid numbers generated.") + +if __name__ == '__main__': + main() From c605aa02826beb730a554dbe004bac1aefc8754e Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 13 Jul 2025 19:13:24 -0700 Subject: [PATCH 02/11] fix env issues --- tutorials/examples/train_rng_gfn.py | 323 +++++++++++++++++++++------- 1 file changed, 242 insertions(+), 81 deletions(-) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py index db4e40de..3f288e09 100644 --- a/tutorials/examples/train_rng_gfn.py +++ b/tutorials/examples/train_rng_gfn.py @@ -1,101 +1,250 @@ import torch import torch.nn as nn +from typing import cast from transformers import AutoModelForCausalLM, AutoTokenizer import numpy as np -from gfn.env import Env +from gfn.env import DiscreteEnv +from gfn.actions import Actions from gfn.states import DiscreteStates from gfn.modules import GFNModule -from gfn.gflownet import GFlowNet, SubTrajectoryBalance -from gfn.samplers import Sampler +from gfn.preprocessors import Preprocessor +from gfn.gflownet import TBGFlowNet +class RNGEnv(DiscreteEnv): + """Environment that builds a number token-by-token after a fixed prompt. -# 1. Environment Definition -class RNGEnv(Env): - def __init__(self, tokenizer, prompt, max_length=5, device='cuda'): + A state is a fixed-length tensor of token ids consisting of the prompt + followed by up to ``max_length`` generated tokens. Padding is done with the + tokenizer pad token id. The episode terminates when an ``eos`` token is + generated **or** when the maximum length is reached (in which case only the + ``eos`` token is allowed). The reward is the uniform probability over the + integers 0-100. 
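+
+    For example, a terminal state whose generated tokens decode to "42" followed
+    by ``eos`` receives log-reward ``log(1/101)``; any other terminal state
+    receives ``-inf``.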
+ """ + + def __init__(self, tokenizer, prompt, max_length: int = 5, device: str | torch.device = "cuda"): self.tokenizer = tokenizer self.prompt = prompt - self.prompt_tokens = tokenizer.encode(prompt, return_tensors='pt').to(device) + device = torch.device(device) + + # Prompt tokens (1, prompt_len) + self.prompt_tokens = tokenizer.encode(prompt, return_tensors="pt").to(device) + prompt_len = self.prompt_tokens.shape[1] + + # Fixed state length (prompt + generated tokens). self.max_length = max_length - self.device = device - - s0 = DiscreteStates(self.prompt_tokens.squeeze(0)) - sf = DiscreteStates(torch.tensor([self.tokenizer.eos_token_id], device=self.device)) - - super().__init__(s0=s0, sf=sf, n_actions=tokenizer.vocab_size) - - def get_actions_masks(self, states: DiscreteStates) -> torch.Tensor: - masks = torch.ones(len(states), self.n_actions, dtype=torch.bool, device=self.device) - for i in range(len(states)): - if states.masks[i].sum() >= self.max_length + self.prompt_tokens.shape[1]: - masks[i, :] = False - masks[i, self.tokenizer.eos_token_id] = True + self.total_length = prompt_len + max_length # total length of the tensor + + self.state_shape = self.total_length # keep name for clarity before tuple wrap + + # Initial state: prompt followed by padding. + s0 = torch.nn.functional.pad( + self.prompt_tokens.squeeze(0), + (0, max_length), + value=tokenizer.pad_token_id, + ) # (state_shape,) + + # Sink state: only the EOS token followed by padding. + sf = torch.nn.functional.pad( + torch.tensor([tokenizer.eos_token_id], device=device), + (0, self.total_length - 1), + value=tokenizer.pad_token_id, + ) + + # We'll treat actions as scalars (shape = ()), so configure accordingly. + super().__init__( + s0=s0, + sf=sf, + n_actions=tokenizer.vocab_size, + state_shape=(self.total_length,), + action_shape=(), + dummy_action=torch.tensor(-1, device=device), + exit_action=torch.tensor(tokenizer.eos_token_id, device=device), + ) + + # --------------------------------------------------------------------- + # Masks helpers + # --------------------------------------------------------------------- + def _forward_action_masks(self, states: DiscreteStates) -> torch.Tensor: + """Returns a bool mask (*batch, n_actions) of valid *forward* actions.""" + batch_size = states.batch_shape[0] + masks = torch.ones((batch_size, self.n_actions), dtype=torch.bool, device=self.device) + + # Current sequence length (non-pad tokens). + seq_lens = (states.tensor != self.tokenizer.pad_token_id).sum(dim=1) + + # If the state is already full forbid everything but EOS. + full_idx = seq_lens >= self.total_length + if full_idx.any(): + masks[full_idx] = False + masks[full_idx, self.tokenizer.eos_token_id] = True + return masks - def forward_step(self, states: DiscreteStates, actions: torch.Tensor) -> DiscreteStates: - new_states_list = [] - for i in range(len(states)): - state_tensor = states.tensor[i][states.masks[i]] - action = actions[i] - new_state_tensor = torch.cat([state_tensor, action.unsqueeze(0)]) - new_states_list.append(new_state_tensor) - return DiscreteStates(new_states_list) + def update_masks(self, states: DiscreteStates) -> None: # type: ignore[override] + """Populate ``states.forward_masks`` and ``states.backward_masks`` in-place.""" + + # Forward masks (valid next tokens). + states.forward_masks = self._forward_action_masks(states) + # Backward masks exclude the exit action (EOS). 
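+        # A backward step undoes the addition of one token; the exit transition
+        # to the sink state has no backward counterpart, so EOS is never a valid
+        # backward action.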
+ backward_masks = torch.ones((*states.batch_shape, self.n_actions), dtype=torch.bool, device=self.device) + backward_masks[..., self.tokenizer.eos_token_id] = False + states.backward_masks = backward_masks + + # --------------------------------------------------------------------- + # Transitions + # --------------------------------------------------------------------- + def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: + # Insert the new token at the first padding position (keeps shape constant). + new_states_tensor = states.tensor.clone() + pad_token_id = self.tokenizer.pad_token_id + for idx in range(len(states)): + pad_positions = (new_states_tensor[idx] == pad_token_id).nonzero(as_tuple=False) + if pad_positions.numel() == 0: + continue # already full, should not happen thanks to masks + first_pad = pad_positions[0].item() + new_states_tensor[idx, first_pad] = actions.tensor[idx] + out = self.States(new_states_tensor) + self.update_masks(cast(DiscreteStates, out)) + return cast(DiscreteStates, out) + + def backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: + # Remove the last token (it should match ``actions``). + pad_token_id = self.tokenizer.pad_token_id + new_states_tensor = states.tensor.clone() + for idx in range(len(states)): + non_pad_positions = (new_states_tensor[idx] != pad_token_id).nonzero(as_tuple=False) + if non_pad_positions.numel() == 0: + continue + last_idx = non_pad_positions[-1].item() + assert new_states_tensor[idx, last_idx] == actions.tensor[idx] + new_states_tensor[idx, last_idx] = pad_token_id + out = self.States(new_states_tensor) + self.update_masks(cast(DiscreteStates, out)) + return cast(DiscreteStates, out) + + # --------------------------------------------------------------------- + # Reward + # --------------------------------------------------------------------- def log_reward(self, states: DiscreteStates) -> torch.Tensor: - rewards = torch.full((len(states),), -20.0, device=self.device) - + """Uniform log-probability for numbers 0-100; −∞ elsewhere.""" + rewards = torch.full((len(states),), float("-inf"), device=self.device) + prompt_len = self.prompt_tokens.shape[1] - - sequences_to_decode = [] - valid_indices = [] - for i in range(len(states)): - if states.masks[i].sum() > prompt_len: - seq = states.tensor[i][states.masks[i]] - # Only reward terminal states - if seq[-1] == self.tokenizer.eos_token_id: - sequences_to_decode.append(seq[prompt_len:-1]) - valid_indices.append(i) - - if not sequences_to_decode: - return rewards - - decoded_texts = self.tokenizer.batch_decode(sequences_to_decode, skip_special_tokens=True) - - for i, decoded_text in enumerate(decoded_texts): - original_index = valid_indices[i] + + for idx in range(len(states)): + # Identify generated part (after prompt) ignoring padding and eos. + seq = states.tensor[idx] + # Determine where padding starts. + pad_mask = seq == self.tokenizer.pad_token_id + if pad_mask.all(): + continue # empty + + # Extract tokens between prompt and eos. + # First generated index after prompt. + gen_start = prompt_len + # Find eos position. + try: + eos_pos = (seq == self.tokenizer.eos_token_id).nonzero(as_tuple=False)[0].item() + except IndexError: + continue # not terminated yet + + # Only consider when eos is present and sequence beyond prompt. 
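+            # An EOS at (or before) the first generated position means no digits
+            # were produced, so the state keeps its default -inf reward.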
+ if eos_pos <= gen_start: + continue + + generated_tokens = seq[gen_start:eos_pos] + if len(generated_tokens) == 0: + continue + + decoded_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True) try: number = int(decoded_text.strip()) if 0 <= number <= 100: - rewards[original_index] = torch.log(torch.tensor(1.0 / 101.0, device=self.device)) - except (ValueError, IndexError): + rewards[idx] = torch.log(torch.tensor(1.0 / 101.0, device=self.device)) + except ValueError: pass - + return rewards - def reset(self, batch_size: int = 1) -> DiscreteStates: - return DiscreteStates([self.s0.tensor for _ in range(batch_size)]) -# 2. GFN Module (wrapping the LLM) +class PassThroughPreprocessor(Preprocessor): + """Returns the raw tensor representation of states (no preprocessing).""" + + def __init__(self, output_dim: int): + super().__init__(output_dim=output_dim) + + def preprocess(self, states): # type: ignore[override] + return states.tensor + + class LLMGFNModule(GFNModule): - def __init__(self, model, tokenizer): - super().__init__() + """GFNModule wrapping a pretrained LLM to act as a policy.""" + + def __init__(self, model, tokenizer, state_dim: int, is_backward: bool = False): + super().__init__(module=model, preprocessor=PassThroughPreprocessor(output_dim=state_dim), is_backward=is_backward) self.model = model self.tokenizer = tokenizer - self.log_z = nn.Parameter(torch.tensor(0.0)) - def forward(self, states: States): + # ------------------------------------------------------------------ + # GFNModule interface + # ------------------------------------------------------------------ + @property + def expected_output_dim(self): + return self.tokenizer.vocab_size + + def forward(self, states: DiscreteStates): # type: ignore[override] input_ids = states.tensor - attention_mask = states.masks + # Build an attention mask: 1 for non-pad. + attention_mask = (input_ids != self.tokenizer.pad_token_id).long() + outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) - - sequence_lengths = attention_mask.sum(dim=1) - last_token_logits = outputs.logits[torch.arange(len(outputs.logits)), sequence_lengths - 1, :] - - return last_token_logits, self.log_z -# 3. Training Setup + # Select logits corresponding to the last *non-pad* token. + seq_lengths = attention_mask.sum(dim=1) # (batch,) + last_token_logits = outputs.logits[torch.arange(len(outputs.logits)), seq_lengths - 1, :] + + return last_token_logits + + def to_probability_distribution( + self, + states: DiscreteStates, + module_output: torch.Tensor, + temperature: float = 1.0, + epsilon: float = 0.0, + sf_bias: float = 0.0, + **kwargs, + ): + """Convert raw logits to a categorical distribution, respecting masks.""" + + masks = states.backward_masks if self.is_backward else states.forward_masks + logits = module_output.clone() + + # Apply masks by setting invalid logits to −inf. 
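+        # Softmax then assigns zero probability to every masked action; the
+        # epsilon mixing below likewise renormalizes over valid actions only.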
+ logits[~masks] = -float("inf") + + if not self.is_backward and sf_bias != 0.0: + logits[:, -1] -= sf_bias # usually bias exit action + + if temperature != 1.0: + logits = logits / temperature + + probs = torch.softmax(logits, dim=-1) + + if epsilon != 0.0: + uniform = torch.where( + masks.sum(dim=-1, keepdim=True) == 0, + torch.zeros_like(masks, dtype=probs.dtype), + masks.float() / masks.sum(dim=-1, keepdim=True), + ) + probs = (1 - epsilon) * probs + epsilon * uniform + + return torch.distributions.Categorical(probs=probs) + + def main(): device = 'cuda' if torch.cuda.is_available() else 'cpu' print(f"Using device: {device}") @@ -108,35 +257,47 @@ def main(): prompt = "The following is a random integer drawn uniformly between 0 and 100: " env = RNGEnv(tokenizer, prompt, device=device, max_length=5) - gfn_module = LLMGFNModule(model, tokenizer) - gflownet = GFlowNet(gfn_module, loss=SubTrajectoryBalance()) - sampler = Sampler(gfn_module, env) - optimizer = torch.optim.Adam(gfn_module.parameters(), lr=1e-4) + state_dim = env.state_shape[0] + pf_module = LLMGFNModule(model, tokenizer, state_dim, is_backward=False) + pb_module = LLMGFNModule(model, tokenizer, state_dim, is_backward=True) + + gflownet = TBGFlowNet(pf_module, pb_module) + + optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-4) + + # ------------------------------------------------------------------ + # Training Loop + # ------------------------------------------------------------------ + for step in range(501): # quick demo run + trajectories = gflownet.sample_trajectories(env, n=8, save_logprobs=False) + loss = gflownet.loss(env, trajectories) - print("Starting training...") - # 4. Training Loop - for i in range(501): # A short training loop for demonstration - trajectories = sampler.sample_trajectories(n_trajectories=8) - loss = gflownet.loss(trajectories) - optimizer.zero_grad() loss.backward() optimizer.step() - - if i % 100 == 0: - print(f"Step {i}, Loss: {loss.item()}") + + if step % 100 == 0: + print(f"Step {step}: loss = {loss.item():.4f}") # 5. Evaluation print("\n--- Evaluation ---") with torch.no_grad(): - trajectories = sampler.sample_trajectories(n_trajectories=100) + trajectories = gflownet.sample_trajectories(env, n=100) final_states = trajectories.last_states numbers = [] for i in range(len(final_states)): - state_tensor = final_states.tensor[i][final_states.masks[i]] - generated_tokens = state_tensor[len(env.prompt_tokens[0]):-1] # Exclude prompt and EOS + state_tensor = final_states.tensor[i] + # Identify non-pad tokens + non_pad = state_tensor != tokenizer.pad_token_id + seq = state_tensor[non_pad] + + # Remove prompt and EOS + prompt_len = env.prompt_tokens.shape[1] + if len(seq) <= prompt_len + 1: + continue + generated_tokens = seq[prompt_len:-1] # exclude prompt and eos decoded_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) try: numbers.append(int(decoded_text.strip())) From 81c8543028118acb1ff90c1e7b440fa626add803 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 13 Jul 2025 20:27:20 -0700 Subject: [PATCH 03/11] inf loss --- tutorials/examples/train_rng_gfn.py | 176 +++++++++++++++++++++------- 1 file changed, 131 insertions(+), 45 deletions(-) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py index 3f288e09..1b7bb817 100644 --- a/tutorials/examples/train_rng_gfn.py +++ b/tutorials/examples/train_rng_gfn.py @@ -1,4 +1,17 @@ +""" +Tutorial: Training a GFlowNet to finetune an LLM for random number generation. 
+ +This tutorial demonstrates how to use TorchGFN to finetune a language model (e.g., GPT-2) +to generate random integers between 0 and 100. The GFlowNet learns to sample from a +uniform distribution over these numbers by using trajectory balance training. + +Usage: + python train_rng_gfn.py --help + python train_rng_gfn.py --n_steps 1000 --batch_size 16 + python train_rng_gfn.py --model_name distilgpt2 --device cpu +""" + import torch import torch.nn as nn from typing import cast @@ -10,7 +23,8 @@ from gfn.states import DiscreteStates from gfn.modules import GFNModule from gfn.preprocessors import Preprocessor -from gfn.gflownet import TBGFlowNet +from gfn.gflownet.trajectory_balance import TBGFlowNet +from gfn.samplers import Sampler class RNGEnv(DiscreteEnv): @@ -37,8 +51,6 @@ def __init__(self, tokenizer, prompt, max_length: int = 5, device: str | torch.d self.max_length = max_length self.total_length = prompt_len + max_length # total length of the tensor - self.state_shape = self.total_length # keep name for clarity before tuple wrap - # Initial state: prompt followed by padding. s0 = torch.nn.functional.pad( self.prompt_tokens.squeeze(0), @@ -53,15 +65,12 @@ def __init__(self, tokenizer, prompt, max_length: int = 5, device: str | torch.d value=tokenizer.pad_token_id, ) - # We'll treat actions as scalars (shape = ()), so configure accordingly. + # Use default action handling like HyperGrid super().__init__( - s0=s0, - sf=sf, n_actions=tokenizer.vocab_size, + s0=s0, state_shape=(self.total_length,), - action_shape=(), - dummy_action=torch.tensor(-1, device=device), - exit_action=torch.tensor(tokenizer.eos_token_id, device=device), + sf=sf, ) # --------------------------------------------------------------------- @@ -80,6 +89,14 @@ def _forward_action_masks(self, states: DiscreteStates) -> torch.Tensor: if full_idx.any(): masks[full_idx] = False masks[full_idx, self.tokenizer.eos_token_id] = True + + # Debug: make sure at least some actions are valid + valid_actions = masks.sum(dim=1) + if torch.any(valid_actions == 0): + print(f"Warning: Some states have no valid actions. seq_lens: {seq_lens}, total_length: {self.total_length}") + # Allow at least EOS token + invalid_states = valid_actions == 0 + masks[invalid_states, self.tokenizer.eos_token_id] = True return masks @@ -88,11 +105,12 @@ def update_masks(self, states: DiscreteStates) -> None: # type: ignore[override # Forward masks (valid next tokens). states.forward_masks = self._forward_action_masks(states) - - # Backward masks exclude the exit action (EOS). 
- backward_masks = torch.ones((*states.batch_shape, self.n_actions), dtype=torch.bool, device=self.device) - backward_masks[..., self.tokenizer.eos_token_id] = False - states.backward_masks = backward_masks + + # Backward masks: can go backward unless we're at the initial state + prompt_len = self.prompt_tokens.shape[1] + seq_lens = (states.tensor != self.tokenizer.pad_token_id).sum(dim=1) + at_initial = seq_lens <= prompt_len + states.backward_masks = ~at_initial.unsqueeze(-1).expand(-1, self.n_actions - 1) # --------------------------------------------------------------------- # Transitions @@ -106,7 +124,7 @@ def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates: if pad_positions.numel() == 0: continue # already full, should not happen thanks to masks first_pad = pad_positions[0].item() - new_states_tensor[idx, first_pad] = actions.tensor[idx] + new_states_tensor[idx, first_pad] = actions.tensor[idx, 0] out = self.States(new_states_tensor) self.update_masks(cast(DiscreteStates, out)) return cast(DiscreteStates, out) @@ -120,7 +138,7 @@ def backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteSta if non_pad_positions.numel() == 0: continue last_idx = non_pad_positions[-1].item() - assert new_states_tensor[idx, last_idx] == actions.tensor[idx] + assert new_states_tensor[idx, last_idx] == actions.tensor[idx, 0] new_states_tensor[idx, last_idx] = pad_token_id out = self.States(new_states_tensor) self.update_masks(cast(DiscreteStates, out)) @@ -220,11 +238,38 @@ def to_probability_distribution( ): """Convert raw logits to a categorical distribution, respecting masks.""" - masks = states.backward_masks if self.is_backward else states.forward_masks logits = module_output.clone() + + if self.is_backward: + # Backward masks exclude the exit action (last action) + masks = states.backward_masks + # Apply masks to all actions except the last one (exit action) + logits[:, :-1][~masks] = -float("inf") + else: + # Forward masks include all actions + masks = states.forward_masks + logits[~masks] = -float("inf") - # Apply masks by setting invalid logits to −inf. 
- logits[~masks] = -float("inf") + # Check for any completely invalid states + if self.is_backward: + valid_mask_counts = masks.sum(dim=-1) + else: + valid_mask_counts = masks.sum(dim=-1) + + if torch.any(valid_mask_counts == 0): + print(f"Warning: Found states with no valid actions in {'backward' if self.is_backward else 'forward'} mode") + print(f"Valid action counts: {valid_mask_counts}") + # Force at least one action to be valid (EOS for forward, or some action for backward) + if self.is_backward: + invalid_idx = valid_mask_counts == 0 + if torch.any(invalid_idx): + # For backward, allow the first non-exit action + masks[invalid_idx, 0] = True + else: + invalid_idx = valid_mask_counts == 0 + if torch.any(invalid_idx): + # For forward, allow EOS + logits[invalid_idx, self.tokenizer.eos_token_id] = 0.0 if not self.is_backward and sf_bias != 0.0: logits[:, -1] -= sf_bias # usually bias exit action @@ -233,44 +278,80 @@ def to_probability_distribution( logits = logits / temperature probs = torch.softmax(logits, dim=-1) + + # Ensure probabilities are valid + probs = torch.clamp(probs, min=1e-8) + probs = probs / probs.sum(dim=-1, keepdim=True) - if epsilon != 0.0: - uniform = torch.where( - masks.sum(dim=-1, keepdim=True) == 0, - torch.zeros_like(masks, dtype=probs.dtype), - masks.float() / masks.sum(dim=-1, keepdim=True), - ) - probs = (1 - epsilon) * probs + epsilon * uniform - - return torch.distributions.Categorical(probs=probs) + # Create a custom distribution that samples with the right shape + categorical = torch.distributions.Categorical(probs=probs) + + class ShapedCategorical: + def __init__(self, cat_dist): + self.cat_dist = cat_dist + + def sample(self): + # Sample from categorical and reshape to (batch_size, 1) + samples = self.cat_dist.sample() + return samples.unsqueeze(-1) + + def log_prob(self, value): + # value should have shape (batch_size, 1), squeeze for categorical + if value.dim() > 1: + value = value.squeeze(-1) + return self.cat_dist.log_prob(value) + + return ShapedCategorical(categorical) def main(): - device = 'cuda' if torch.cuda.is_available() else 'cpu' + import argparse + + parser = argparse.ArgumentParser(description="Train GFlowNet to generate random numbers 0-100") + parser.add_argument("--device", default="auto", help="Device to use (auto, cpu, cuda)") + parser.add_argument("--model_name", default="gpt2", help="Model name from HuggingFace") + parser.add_argument("--n_steps", type=int, default=500, help="Number of training steps") + parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training") + parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") + parser.add_argument("--max_length", type=int, default=5, help="Max tokens to generate") + parser.add_argument("--eval_samples", type=int, default=100, help="Number of samples for evaluation") + + args = parser.parse_args() + + # Device setup + if args.device == "auto": + device = 'cuda' if torch.cuda.is_available() else 'cpu' + else: + device = args.device print(f"Using device: {device}") - tokenizer = AutoTokenizer.from_pretrained('gpt2') + # Model and tokenizer setup + print(f"Loading model: {args.model_name}") + tokenizer = AutoTokenizer.from_pretrained(args.model_name) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained('gpt2').to(device) + model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device) + # Environment setup prompt = "The following is a random integer drawn 
uniformly between 0 and 100: " - env = RNGEnv(tokenizer, prompt, device=device, max_length=5) + env = RNGEnv(tokenizer, prompt, device=device, max_length=args.max_length) + print(f"Environment set up with prompt: '{prompt}'") - state_dim = env.state_shape[0] + # GFlowNet setup + state_dim = env.total_length pf_module = LLMGFNModule(model, tokenizer, state_dim, is_backward=False) pb_module = LLMGFNModule(model, tokenizer, state_dim, is_backward=True) gflownet = TBGFlowNet(pf_module, pb_module) + sampler = Sampler(pf_module) - optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-4) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=args.lr) + print(f"Training for {args.n_steps} steps with batch size {args.batch_size}") - # ------------------------------------------------------------------ # Training Loop - # ------------------------------------------------------------------ - for step in range(501): # quick demo run - trajectories = gflownet.sample_trajectories(env, n=8, save_logprobs=False) + for step in range(args.n_steps): + trajectories = sampler.sample_trajectories(env, n=args.batch_size, save_logprobs=False) loss = gflownet.loss(env, trajectories) optimizer.zero_grad() @@ -278,12 +359,13 @@ def main(): optimizer.step() if step % 100 == 0: - print(f"Step {step}: loss = {loss.item():.4f}") + print(f"Step {step:4d}: loss = {loss.item():.4f}") - # 5. Evaluation + # Evaluation print("\n--- Evaluation ---") + print(f"Sampling {args.eval_samples} trajectories for evaluation...") with torch.no_grad(): - trajectories = gflownet.sample_trajectories(env, n=100) + trajectories = sampler.sample_trajectories(env, n=args.eval_samples) final_states = trajectories.last_states numbers = [] @@ -300,15 +382,19 @@ def main(): generated_tokens = seq[prompt_len:-1] # exclude prompt and eos decoded_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) try: - numbers.append(int(decoded_text.strip())) + number = int(decoded_text.strip()) + if 0 <= number <= 100: # Only count valid numbers in range + numbers.append(number) except (ValueError, IndexError): pass - print(f"Generated numbers: {numbers}") + print(f"Generated valid numbers: {numbers[:20]}...") # Show first 20 if numbers: counts = np.bincount(numbers, minlength=101) - print(f"Counts of numbers 0-100: {counts}") - print(f"Number of valid samples: {len(numbers)}") + valid_counts = counts[:101] # Only 0-100 + print(f"Number of valid samples: {len(numbers)} / {args.eval_samples} ({100*len(numbers)/args.eval_samples:.1f}%)") + print(f"Unique numbers generated: {np.count_nonzero(valid_counts)} / 101") + print(f"Mean: {np.mean(numbers):.1f}, Std: {np.std(numbers):.1f}") else: print("No valid numbers generated.") From 6859e7cd4578ffd1bfd970a8619dab72613c34e6 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 20 Jul 2025 04:47:01 -0700 Subject: [PATCH 04/11] fixes --- tutorials/examples/train_rng_gfn.py | 86 ++++++++++++++--------------- 1 file changed, 41 insertions(+), 45 deletions(-) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py index 1b7bb817..341bd8d2 100644 --- a/tutorials/examples/train_rng_gfn.py +++ b/tutorials/examples/train_rng_gfn.py @@ -7,13 +7,11 @@ uniform distribution over these numbers by using trajectory balance training. 
Usage: - python train_rng_gfn.py --help python train_rng_gfn.py --n_steps 1000 --batch_size 16 python train_rng_gfn.py --model_name distilgpt2 --device cpu """ import torch -import torch.nn as nn from typing import cast from transformers import AutoModelForCausalLM, AutoTokenizer import numpy as np @@ -21,8 +19,8 @@ from gfn.env import DiscreteEnv from gfn.actions import Actions from gfn.states import DiscreteStates -from gfn.modules import GFNModule -from gfn.preprocessors import Preprocessor +from gfn.modules import DiscretePolicyEstimator, GFNModule +from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.gflownet.trajectory_balance import TBGFlowNet from gfn.samplers import Sampler @@ -90,27 +88,36 @@ def _forward_action_masks(self, states: DiscreteStates) -> torch.Tensor: masks[full_idx] = False masks[full_idx, self.tokenizer.eos_token_id] = True - # Debug: make sure at least some actions are valid - valid_actions = masks.sum(dim=1) - if torch.any(valid_actions == 0): - print(f"Warning: Some states have no valid actions. seq_lens: {seq_lens}, total_length: {self.total_length}") - # Allow at least EOS token - invalid_states = valid_actions == 0 - masks[invalid_states, self.tokenizer.eos_token_id] = True - return masks + + def _backward_action_masks(self, states: DiscreteStates) -> torch.Tensor: + """Returns a bool mask (*batch, n_actions) of valid *backward* actions.""" + prompt_len = self.prompt_tokens.shape[1] + seq_lens = (states.tensor != self.tokenizer.pad_token_id).sum(dim=1) + at_initial = seq_lens <= prompt_len + + # Backward mask: only the last non-pad token can be removed (one True per row, rest False) + batch_size = states.tensor.shape[0] + backward_masks = torch.zeros((batch_size, self.n_actions - 1), dtype=torch.bool, device=states.tensor.device) + + # Find which sequences can go backward (not at initial state) + can_go_back = ~at_initial + if can_go_back.any(): + # Get the last token ID for each sequence that can go backward + last_token_indices = seq_lens[can_go_back] - 1 + batch_indices = torch.arange(0, batch_size, device=states.tensor.device)[can_go_back] + last_token_ids = states.tensor[batch_indices, last_token_indices] + + assert (last_token_ids < self.n_actions - 1).all() + backward_masks[batch_indices, last_token_ids] = True + return backward_masks def update_masks(self, states: DiscreteStates) -> None: # type: ignore[override] """Populate ``states.forward_masks`` and ``states.backward_masks`` in-place.""" - # Forward masks (valid next tokens). 
states.forward_masks = self._forward_action_masks(states) - # Backward masks: can go backward unless we're at the initial state - prompt_len = self.prompt_tokens.shape[1] - seq_lens = (states.tensor != self.tokenizer.pad_token_id).sum(dim=1) - at_initial = seq_lens <= prompt_len - states.backward_masks = ~at_initial.unsqueeze(-1).expand(-1, self.n_actions - 1) + states.backward_masks = self._backward_action_masks(states) # --------------------------------------------------------------------- # Transitions @@ -149,10 +156,9 @@ def backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteSta # --------------------------------------------------------------------- def log_reward(self, states: DiscreteStates) -> torch.Tensor: """Uniform log-probability for numbers 0-100; −∞ elsewhere.""" - rewards = torch.full((len(states),), float("-inf"), device=self.device) + rewards = torch.full((len(states),), -1e4, device=self.device) prompt_len = self.prompt_tokens.shape[1] - for idx in range(len(states)): # Identify generated part (after prompt) ignoring padding and eos. seq = states.tensor[idx] @@ -165,15 +171,10 @@ def log_reward(self, states: DiscreteStates) -> torch.Tensor: # First generated index after prompt. gen_start = prompt_len # Find eos position. - try: - eos_pos = (seq == self.tokenizer.eos_token_id).nonzero(as_tuple=False)[0].item() - except IndexError: + eos_positions = (seq == self.tokenizer.eos_token_id).nonzero(as_tuple=False) + if eos_positions.numel() == 0: continue # not terminated yet - - # Only consider when eos is present and sequence beyond prompt. - if eos_pos <= gen_start: - continue - + eos_pos = eos_positions[0].item() generated_tokens = seq[gen_start:eos_pos] if len(generated_tokens) == 0: continue @@ -182,28 +183,23 @@ def log_reward(self, states: DiscreteStates) -> torch.Tensor: try: number = int(decoded_text.strip()) if 0 <= number <= 100: - rewards[idx] = torch.log(torch.tensor(1.0 / 101.0, device=self.device)) + rewards[idx] = 0.0 except ValueError: pass return rewards -class PassThroughPreprocessor(Preprocessor): - """Returns the raw tensor representation of states (no preprocessing).""" - - def __init__(self, output_dim: int): - super().__init__(output_dim=output_dim) - - def preprocess(self, states): # type: ignore[override] - return states.tensor - - -class LLMGFNModule(GFNModule): +class LLMGFNModule(DiscretePolicyEstimator): """GFNModule wrapping a pretrained LLM to act as a policy.""" def __init__(self, model, tokenizer, state_dim: int, is_backward: bool = False): - super().__init__(module=model, preprocessor=PassThroughPreprocessor(output_dim=state_dim), is_backward=is_backward) + super().__init__( + module=model, + n_actions=tokenizer.vocab_size, + preprocessor=IdentityPreprocessor(state_dim), + is_backward=is_backward + ) self.model = model self.tokenizer = tokenizer @@ -310,8 +306,8 @@ def main(): parser = argparse.ArgumentParser(description="Train GFlowNet to generate random numbers 0-100") parser.add_argument("--device", default="auto", help="Device to use (auto, cpu, cuda)") parser.add_argument("--model_name", default="gpt2", help="Model name from HuggingFace") - parser.add_argument("--n_steps", type=int, default=500, help="Number of training steps") - parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training") + parser.add_argument("--n_steps", type=int, default=50, help="Number of training steps") + parser.add_argument("--batch_size", type=int, default=128, help="Batch size for training") 
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") parser.add_argument("--max_length", type=int, default=5, help="Max tokens to generate") parser.add_argument("--eval_samples", type=int, default=100, help="Number of samples for evaluation") @@ -358,7 +354,7 @@ def main(): loss.backward() optimizer.step() - if step % 100 == 0: + if step % 1 == 0: print(f"Step {step:4d}: loss = {loss.item():.4f}") # Evaluation @@ -366,7 +362,7 @@ def main(): print(f"Sampling {args.eval_samples} trajectories for evaluation...") with torch.no_grad(): trajectories = sampler.sample_trajectories(env, n=args.eval_samples) - final_states = trajectories.last_states + final_states = trajectories.terminating_states numbers = [] for i in range(len(final_states)): From c1516d2d7056a0781fd2fcef8f352e280a5aaa94 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 21 Jul 2025 18:48:54 +0200 Subject: [PATCH 05/11] add LoRa --- tutorials/examples/train_rng_gfn.py | 52 ++++++++++++++++++++++++++--- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py index 341bd8d2..596f5f93 100644 --- a/tutorials/examples/train_rng_gfn.py +++ b/tutorials/examples/train_rng_gfn.py @@ -4,17 +4,25 @@ This tutorial demonstrates how to use TorchGFN to finetune a language model (e.g., GPT-2) to generate random integers between 0 and 100. The GFlowNet learns to sample from a -uniform distribution over these numbers by using trajectory balance training. +uniform distribution over these numbers by using trajectory balance training. + +Supports both parameter-efficient fine-tuning with LoRA (default) and full fine-tuning. Usage: + # LoRA training (default) python train_rng_gfn.py --n_steps 1000 --batch_size 16 - python train_rng_gfn.py --model_name distilgpt2 --device cpu + + # Custom LoRA configuration + python train_rng_gfn.py --lora_r 16 --use_lora --lora_alpha 32 --target_modules c_attn c_proj """ import torch from typing import cast from transformers import AutoModelForCausalLM, AutoTokenizer import numpy as np +from peft.tuners.lora import LoraConfig +from peft.mapping import get_peft_model +from peft.utils.peft_types import TaskType from gfn.env import DiscreteEnv from gfn.actions import Actions @@ -307,11 +315,19 @@ def main(): parser.add_argument("--device", default="auto", help="Device to use (auto, cpu, cuda)") parser.add_argument("--model_name", default="gpt2", help="Model name from HuggingFace") parser.add_argument("--n_steps", type=int, default=50, help="Number of training steps") - parser.add_argument("--batch_size", type=int, default=128, help="Batch size for training") + parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training") parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") parser.add_argument("--max_length", type=int, default=5, help="Max tokens to generate") parser.add_argument("--eval_samples", type=int, default=100, help="Number of samples for evaluation") + # LoRA-specific arguments + parser.add_argument("--use_lora", action="store_true", default=False, help="Use LoRA for parameter-efficient fine-tuning") + parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank") + parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha scaling parameter") + parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout probability") + parser.add_argument("--target_modules", nargs="+", default=["c_attn", "c_proj"], + 
help="Target modules for LoRA adaptation (default for GPT-2)") + args = parser.parse_args() # Device setup @@ -327,7 +343,30 @@ def main(): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token + # Load base model model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device) + + if args.use_lora: + # Configure LoRA + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + r=args.lora_r, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + target_modules=args.target_modules, + bias="none", + ) + + # Apply LoRA to the model + model = get_peft_model(model, lora_config) + model.print_trainable_parameters() + print(f"Applied LoRA with rank={args.lora_r}, alpha={args.lora_alpha}, target_modules={args.target_modules}") + else: + print("Using full fine-tuning (LoRA disabled)") + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + print(f"Total parameters: {total_params:,}") + print(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / total_params:.2f}%)") # Environment setup prompt = "The following is a random integer drawn uniformly between 0 and 100: " @@ -342,8 +381,13 @@ def main(): gflownet = TBGFlowNet(pf_module, pb_module) sampler = Sampler(pf_module) - optimizer = torch.optim.Adam(gflownet.parameters(), lr=args.lr) + # Set up optimizer for trainable parameters + trainable_params = [p for p in gflownet.parameters() if p.requires_grad] + optimizer = torch.optim.Adam(trainable_params, lr=args.lr) + + param_count = sum(p.numel() for p in trainable_params) print(f"Training for {args.n_steps} steps with batch size {args.batch_size}") + print(f"Optimizing {param_count:,} trainable parameters across {len(trainable_params)} parameter groups") # Training Loop for step in range(args.n_steps): From 2a3ca0aac803b435629e7cc10670844780bbae6f Mon Sep 17 00:00:00 2001 From: younik Date: Tue, 22 Jul 2025 12:29:36 -0400 Subject: [PATCH 06/11] improve optimizer --- tutorials/examples/train_rng_gfn.py | 255 ++++++++++++++++++++-------- 1 file changed, 185 insertions(+), 70 deletions(-) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py index 596f5f93..3912b044 100644 --- a/tutorials/examples/train_rng_gfn.py +++ b/tutorials/examples/train_rng_gfn.py @@ -6,14 +6,22 @@ to generate random integers between 0 and 100. The GFlowNet learns to sample from a uniform distribution over these numbers by using trajectory balance training. -Supports both parameter-efficient fine-tuning with LoRA (default) and full fine-tuning. 
+Features: +- Supports both parameter-efficient fine-tuning with LoRA and full fine-tuning +- Uses AdamW optimizer with weight decay (standard for transformer models) +- Configurable learning rate scheduling (cosine, linear, or constant) +- Gradient clipping for training stability +- Warmup steps for better convergence Usage: - # LoRA training (default) - python train_rng_gfn.py --n_steps 1000 --batch_size 16 + # LoRA training with cosine scheduler (recommended) + python train_rng_gfn.py --use_lora --n_steps 1000 --batch_size 16 --warmup_steps 100 - # Custom LoRA configuration - python train_rng_gfn.py --lora_r 16 --use_lora --lora_alpha 32 --target_modules c_attn c_proj + # Full fine-tuning with linear scheduler + python train_rng_gfn.py --n_steps 500 --scheduler_type linear --lr 1e-5 --weight_decay 0.01 + + # Custom LoRA configuration with different scheduler + python train_rng_gfn.py --lora_r 16 --use_lora --lora_alpha 32 --target_modules c_attn c_proj --scheduler_type constant """ import torch @@ -21,8 +29,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer import numpy as np from peft.tuners.lora import LoraConfig -from peft.mapping import get_peft_model +from peft import get_peft_model from peft.utils.peft_types import TaskType +from transformers import get_cosine_schedule_with_warmup, get_linear_schedule_with_warmup from gfn.env import DiscreteEnv from gfn.actions import Actions @@ -79,13 +88,19 @@ def __init__(self, tokenizer, prompt, max_length: int = 5, device: str | torch.d sf=sf, ) + @property + def device(self) -> torch.device: + """Returns the device of the environment.""" + return self.s0.device + # --------------------------------------------------------------------- # Masks helpers # --------------------------------------------------------------------- def _forward_action_masks(self, states: DiscreteStates) -> torch.Tensor: """Returns a bool mask (*batch, n_actions) of valid *forward* actions.""" batch_size = states.batch_shape[0] - masks = torch.ones((batch_size, self.n_actions), dtype=torch.bool, device=self.device) + # Use the device from the states tensor instead of self.device + masks = torch.ones((batch_size, self.n_actions), dtype=torch.bool, device=states.tensor.device) # Current sequence length (non-pad tokens). 
seq_lens = (states.tensor != self.tokenizer.pad_token_id).sum(dim=1) @@ -113,11 +128,15 @@ def _backward_action_masks(self, states: DiscreteStates) -> torch.Tensor: if can_go_back.any(): # Get the last token ID for each sequence that can go backward last_token_indices = seq_lens[can_go_back] - 1 - batch_indices = torch.arange(0, batch_size, device=states.tensor.device)[can_go_back] + batch_indices = torch.arange(batch_size, device=states.tensor.device)[can_go_back] last_token_ids = states.tensor[batch_indices, last_token_indices] - assert (last_token_ids < self.n_actions - 1).all() - backward_masks[batch_indices, last_token_ids] = True + # Ensure token IDs are valid for backward actions (exclude exit action) + valid_token_mask = last_token_ids < self.n_actions - 1 + if valid_token_mask.any(): + valid_batch_indices = batch_indices[valid_token_mask] + valid_token_ids = last_token_ids[valid_token_mask] + backward_masks[valid_batch_indices, valid_token_ids] = True return backward_masks def update_masks(self, states: DiscreteStates) -> None: # type: ignore[override] @@ -164,7 +183,7 @@ def backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteSta # --------------------------------------------------------------------- def log_reward(self, states: DiscreteStates) -> torch.Tensor: """Uniform log-probability for numbers 0-100; −∞ elsewhere.""" - rewards = torch.full((len(states),), -1e4, device=self.device) + rewards = torch.full((len(states),), -1e2, device=self.device) prompt_len = self.prompt_tokens.shape[1] for idx in range(len(states)): @@ -249,31 +268,31 @@ def to_probability_distribution( masks = states.backward_masks # Apply masks to all actions except the last one (exit action) logits[:, :-1][~masks] = -float("inf") + # Always mask out the exit action for backward steps + logits[:, -1] = -float("inf") else: # Forward masks include all actions masks = states.forward_masks logits[~masks] = -float("inf") - # Check for any completely invalid states - if self.is_backward: - valid_mask_counts = masks.sum(dim=-1) - else: - valid_mask_counts = masks.sum(dim=-1) + # Check for any completely invalid states and fix them + valid_mask_counts = masks.sum(dim=-1) + invalid_idx = valid_mask_counts == 0 - if torch.any(valid_mask_counts == 0): - print(f"Warning: Found states with no valid actions in {'backward' if self.is_backward else 'forward'} mode") - print(f"Valid action counts: {valid_mask_counts}") - # Force at least one action to be valid (EOS for forward, or some action for backward) + if torch.any(invalid_idx): + print(f"Warning: Found {invalid_idx.sum().item()} states with no valid actions in {'backward' if self.is_backward else 'forward'} mode") + # Force at least one action to be valid if self.is_backward: - invalid_idx = valid_mask_counts == 0 - if torch.any(invalid_idx): - # For backward, allow the first non-exit action - masks[invalid_idx, 0] = True + # For backward, allow the first non-exit action (typically action 0) + masks[invalid_idx, 0] = True + logits[invalid_idx, :-1] = -float("inf") # mask out all non-exit actions + logits[invalid_idx, 0] = 0.0 # set action 0 to be valid + logits[invalid_idx, -1] = -float("inf") # ensure exit action remains masked else: - invalid_idx = valid_mask_counts == 0 - if torch.any(invalid_idx): - # For forward, allow EOS - logits[invalid_idx, self.tokenizer.eos_token_id] = 0.0 + # For forward, allow EOS token + masks[invalid_idx, self.tokenizer.eos_token_id] = True + logits[invalid_idx, :] = -float("inf") # mask out all actions + 
logits[invalid_idx, self.tokenizer.eos_token_id] = 0.0 # allow EOS if not self.is_backward and sf_bias != 0.0: logits[:, -1] -= sf_bias # usually bias exit action @@ -308,26 +327,99 @@ def log_prob(self, value): return ShapedCategorical(categorical) +def evaluate_model(env, sampler, tokenizer, n_samples=100, step_name="Evaluation"): + """Evaluate the model by sampling trajectories and analyzing generated numbers.""" + print(f"\n--- {step_name} ---") + print(f"Sampling {n_samples} trajectories for evaluation...") + + with torch.no_grad(): + trajectories = sampler.sample_trajectories(env, n=n_samples) + final_states = trajectories.terminating_states + + numbers = [] + for i in range(len(final_states)): + state_tensor = final_states.tensor[i] + # Identify non-pad tokens + non_pad = state_tensor != tokenizer.pad_token_id + seq = state_tensor[non_pad] + + # Remove prompt and EOS + prompt_len = env.prompt_tokens.shape[1] + if len(seq) <= prompt_len + 1: + continue + generated_tokens = seq[prompt_len:-1] # exclude prompt and eos + decoded_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) + try: + number = int(decoded_text.strip()) + if 0 <= number <= 100: # Only count valid numbers in range + numbers.append(number) + except (ValueError, IndexError): + pass + + print(f"Generated valid numbers: {numbers[:20]}...") # Show first 20 + + results = {} + if numbers: + counts = np.bincount(numbers, minlength=101) + valid_counts = counts[:101] # Only 0-100 + success_rate = len(numbers) / n_samples + unique_numbers = np.count_nonzero(valid_counts) + mean_val = np.mean(numbers) + std_val = np.std(numbers) + + print(f"Number of valid samples: {len(numbers)} / {n_samples} ({100*success_rate:.1f}%)") + print(f"Unique numbers generated: {unique_numbers} / 101") + print(f"Mean: {mean_val:.1f}, Std: {std_val:.1f}") + + results = { + 'valid_numbers': numbers, + 'success_rate': success_rate, + 'unique_count': unique_numbers, + 'mean': mean_val, + 'std': std_val, + 'total_valid': len(numbers) + } + else: + print("No valid numbers generated.") + results = { + 'valid_numbers': [], + 'success_rate': 0.0, + 'unique_count': 0, + 'mean': 0.0, + 'std': 0.0, + 'total_valid': 0 + } + + return results + + def main(): import argparse parser = argparse.ArgumentParser(description="Train GFlowNet to generate random numbers 0-100") parser.add_argument("--device", default="auto", help="Device to use (auto, cpu, cuda)") parser.add_argument("--model_name", default="gpt2", help="Model name from HuggingFace") - parser.add_argument("--n_steps", type=int, default=50, help="Number of training steps") - parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training") + parser.add_argument("--n_steps", type=int, default=200, help="Number of training steps") + parser.add_argument("--batch_size", type=int, default=128, help="Batch size for training") parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") - parser.add_argument("--max_length", type=int, default=5, help="Max tokens to generate") + parser.add_argument("--max_length", type=int, default=2, help="Max tokens to generate") parser.add_argument("--eval_samples", type=int, default=100, help="Number of samples for evaluation") # LoRA-specific arguments parser.add_argument("--use_lora", action="store_true", default=False, help="Use LoRA for parameter-efficient fine-tuning") - parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank") + parser.add_argument("--lora_r", type=int, default=32, help="LoRA rank") 
parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha scaling parameter") parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout probability") parser.add_argument("--target_modules", nargs="+", default=["c_attn", "c_proj"], help="Target modules for LoRA adaptation (default for GPT-2)") + # Optimizer and scheduler arguments + parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer") + parser.add_argument("--warmup_steps", type=int, default=100, help="Number of warmup steps for learning rate scheduler") + parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for clipping") + parser.add_argument("--scheduler_type", default="cosine", choices=["cosine", "linear", "constant"], + help="Learning rate scheduler type") + args = parser.parse_args() # Device setup @@ -381,14 +473,45 @@ def main(): gflownet = TBGFlowNet(pf_module, pb_module) sampler = Sampler(pf_module) - # Set up optimizer for trainable parameters + # Set up optimizer and scheduler for trainable parameters trainable_params = [p for p in gflownet.parameters() if p.requires_grad] - optimizer = torch.optim.Adam(trainable_params, lr=args.lr) + + # Use AdamW optimizer (preferred for transformers) + optimizer = torch.optim.AdamW( + trainable_params, + lr=args.lr, + weight_decay=args.weight_decay, + betas=(0.9, 0.999), + eps=1e-8 + ) + + # Set up learning rate scheduler + if args.scheduler_type == "cosine": + scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=args.n_steps + ) + elif args.scheduler_type == "linear": + scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=args.n_steps + ) + else: # constant + scheduler = None param_count = sum(p.numel() for p in trainable_params) print(f"Training for {args.n_steps} steps with batch size {args.batch_size}") print(f"Optimizing {param_count:,} trainable parameters across {len(trainable_params)} parameter groups") + print(f"Using AdamW optimizer with lr={args.lr}, weight_decay={args.weight_decay}") + print(f"Learning rate scheduler: {args.scheduler_type}, warmup_steps={args.warmup_steps}") + print(f"Gradient clipping: max_norm={args.max_grad_norm}") + # Pre-training evaluation + print("Evaluating model performance before training...") + pre_training_results = evaluate_model(env, sampler, tokenizer, args.eval_samples, "Pre-training Evaluation") + # Training Loop for step in range(args.n_steps): trajectories = sampler.sample_trajectories(env, n=args.batch_size, save_logprobs=False) @@ -396,47 +519,39 @@ def main(): optimizer.zero_grad() loss.backward() + + # Gradient clipping for stability + if args.max_grad_norm > 0: + grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm) + optimizer.step() + + # Update learning rate scheduler + if scheduler is not None: + scheduler.step() if step % 1 == 0: - print(f"Step {step:4d}: loss = {loss.item():.4f}") + current_lr = optimizer.param_groups[0]['lr'] + if args.max_grad_norm > 0: + print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}, grad_norm = {grad_norm:.4f}") + else: + print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}") + + # Post-training evaluation + post_training_results = evaluate_model(env, sampler, tokenizer, args.eval_samples, "Post-training Evaluation") + + # Compare results + print("\n--- Training Results Comparison ---") + 
print(f"Success Rate: {pre_training_results['success_rate']:.1%} → {post_training_results['success_rate']:.1%} " + f"(Δ: {post_training_results['success_rate'] - pre_training_results['success_rate']:+.1%})") + print(f"Unique Numbers: {pre_training_results['unique_count']} → {post_training_results['unique_count']} " + f"(Δ: {post_training_results['unique_count'] - pre_training_results['unique_count']:+d})") + print(f"Mean Value: {pre_training_results['mean']:.1f} → {post_training_results['mean']:.1f} " + f"(Δ: {post_training_results['mean'] - pre_training_results['mean']:+.1f})") + print(f"Std Deviation: {pre_training_results['std']:.1f} → {post_training_results['std']:.1f} " + f"(Δ: {post_training_results['std'] - pre_training_results['std']:+.1f})") - # Evaluation - print("\n--- Evaluation ---") - print(f"Sampling {args.eval_samples} trajectories for evaluation...") - with torch.no_grad(): - trajectories = sampler.sample_trajectories(env, n=args.eval_samples) - final_states = trajectories.terminating_states - - numbers = [] - for i in range(len(final_states)): - state_tensor = final_states.tensor[i] - # Identify non-pad tokens - non_pad = state_tensor != tokenizer.pad_token_id - seq = state_tensor[non_pad] - # Remove prompt and EOS - prompt_len = env.prompt_tokens.shape[1] - if len(seq) <= prompt_len + 1: - continue - generated_tokens = seq[prompt_len:-1] # exclude prompt and eos - decoded_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) - try: - number = int(decoded_text.strip()) - if 0 <= number <= 100: # Only count valid numbers in range - numbers.append(number) - except (ValueError, IndexError): - pass - - print(f"Generated valid numbers: {numbers[:20]}...") # Show first 20 - if numbers: - counts = np.bincount(numbers, minlength=101) - valid_counts = counts[:101] # Only 0-100 - print(f"Number of valid samples: {len(numbers)} / {args.eval_samples} ({100*len(numbers)/args.eval_samples:.1f}%)") - print(f"Unique numbers generated: {np.count_nonzero(valid_counts)} / 101") - print(f"Mean: {np.mean(numbers):.1f}, Std: {np.std(numbers):.1f}") - else: - print("No valid numbers generated.") if __name__ == '__main__': main() From 54305794eac41abb2a8c090a501f227127402a7f Mon Sep 17 00:00:00 2001 From: younik Date: Sat, 2 Aug 2025 06:09:08 -0400 Subject: [PATCH 07/11] update code --- tutorials/examples/train_rng_gfn.py | 299 ++++++++++++++++------------ 1 file changed, 174 insertions(+), 125 deletions(-) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py index 3912b044..91bbeb2f 100644 --- a/tutorials/examples/train_rng_gfn.py +++ b/tutorials/examples/train_rng_gfn.py @@ -12,6 +12,7 @@ - Configurable learning rate scheduling (cosine, linear, or constant) - Gradient clipping for training stability - Warmup steps for better convergence +- Optional replay buffer for improved training stability Usage: # LoRA training with cosine scheduler (recommended) @@ -20,6 +21,9 @@ # Full fine-tuning with linear scheduler python train_rng_gfn.py --n_steps 500 --scheduler_type linear --lr 1e-5 --weight_decay 0.01 + # Training with replay buffer for better stability + python train_rng_gfn.py --use_lora --use_buffer --n_steps 1000 --batch_size 16 + # Custom LoRA configuration with different scheduler python train_rng_gfn.py --lora_r 16 --use_lora --lora_alpha 32 --target_modules c_attn c_proj --scheduler_type constant """ @@ -28,6 +32,8 @@ from typing import cast from transformers import AutoModelForCausalLM, AutoTokenizer import numpy as np +import 
matplotlib.pyplot as plt +import os from peft.tuners.lora import LoraConfig from peft import get_peft_model from peft.utils.peft_types import TaskType @@ -40,6 +46,7 @@ from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.gflownet.trajectory_balance import TBGFlowNet from gfn.samplers import Sampler +from gfn.containers import ReplayBuffer class RNGEnv(DiscreteEnv): @@ -121,7 +128,7 @@ def _backward_action_masks(self, states: DiscreteStates) -> torch.Tensor: # Backward mask: only the last non-pad token can be removed (one True per row, rest False) batch_size = states.tensor.shape[0] - backward_masks = torch.zeros((batch_size, self.n_actions - 1), dtype=torch.bool, device=states.tensor.device) + backward_masks = torch.zeros((batch_size, self.n_actions), dtype=torch.bool, device=states.tensor.device) # Find which sequences can go backward (not at initial state) can_go_back = ~at_initial @@ -129,14 +136,8 @@ def _backward_action_masks(self, states: DiscreteStates) -> torch.Tensor: # Get the last token ID for each sequence that can go backward last_token_indices = seq_lens[can_go_back] - 1 batch_indices = torch.arange(batch_size, device=states.tensor.device)[can_go_back] - last_token_ids = states.tensor[batch_indices, last_token_indices] - - # Ensure token IDs are valid for backward actions (exclude exit action) - valid_token_mask = last_token_ids < self.n_actions - 1 - if valid_token_mask.any(): - valid_batch_indices = batch_indices[valid_token_mask] - valid_token_ids = last_token_ids[valid_token_mask] - backward_masks[valid_batch_indices, valid_token_ids] = True + last_token_ids = states.tensor[batch_indices, last_token_indices] + backward_masks[batch_indices, last_token_ids] = True return backward_masks def update_masks(self, states: DiscreteStates) -> None: # type: ignore[override] @@ -264,51 +265,16 @@ def to_probability_distribution( logits = module_output.clone() if self.is_backward: - # Backward masks exclude the exit action (last action) + # Backward masks masks = states.backward_masks - # Apply masks to all actions except the last one (exit action) - logits[:, :-1][~masks] = -float("inf") - # Always mask out the exit action for backward steps - logits[:, -1] = -float("inf") + logits[~masks] = -float("inf") else: # Forward masks include all actions masks = states.forward_masks logits[~masks] = -float("inf") - # Check for any completely invalid states and fix them - valid_mask_counts = masks.sum(dim=-1) - invalid_idx = valid_mask_counts == 0 - - if torch.any(invalid_idx): - print(f"Warning: Found {invalid_idx.sum().item()} states with no valid actions in {'backward' if self.is_backward else 'forward'} mode") - # Force at least one action to be valid - if self.is_backward: - # For backward, allow the first non-exit action (typically action 0) - masks[invalid_idx, 0] = True - logits[invalid_idx, :-1] = -float("inf") # mask out all non-exit actions - logits[invalid_idx, 0] = 0.0 # set action 0 to be valid - logits[invalid_idx, -1] = -float("inf") # ensure exit action remains masked - else: - # For forward, allow EOS token - masks[invalid_idx, self.tokenizer.eos_token_id] = True - logits[invalid_idx, :] = -float("inf") # mask out all actions - logits[invalid_idx, self.tokenizer.eos_token_id] = 0.0 # allow EOS - - if not self.is_backward and sf_bias != 0.0: - logits[:, -1] -= sf_bias # usually bias exit action - - if temperature != 1.0: - logits = logits / temperature - - probs = torch.softmax(logits, dim=-1) - - # Ensure probabilities are valid - probs = 
torch.clamp(probs, min=1e-8) - probs = probs / probs.sum(dim=-1, keepdim=True) - # Create a custom distribution that samples with the right shape - categorical = torch.distributions.Categorical(probs=probs) - + categorical = torch.distributions.Categorical(logits=logits) class ShapedCategorical: def __init__(self, cat_dist): self.cat_dist = cat_dist @@ -327,70 +293,73 @@ def log_prob(self, value): return ShapedCategorical(categorical) -def evaluate_model(env, sampler, tokenizer, n_samples=100, step_name="Evaluation"): +def evaluate_model(env, trajectories, tokenizer, step_name="Evaluation"): """Evaluate the model by sampling trajectories and analyzing generated numbers.""" - print(f"\n--- {step_name} ---") - print(f"Sampling {n_samples} trajectories for evaluation...") + final_states = trajectories.terminating_states - with torch.no_grad(): - trajectories = sampler.sample_trajectories(env, n=n_samples) - final_states = trajectories.terminating_states - - numbers = [] - for i in range(len(final_states)): - state_tensor = final_states.tensor[i] - # Identify non-pad tokens - non_pad = state_tensor != tokenizer.pad_token_id - seq = state_tensor[non_pad] - - # Remove prompt and EOS - prompt_len = env.prompt_tokens.shape[1] - if len(seq) <= prompt_len + 1: - continue - generated_tokens = seq[prompt_len:-1] # exclude prompt and eos - decoded_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) - try: - number = int(decoded_text.strip()) - if 0 <= number <= 100: # Only count valid numbers in range - numbers.append(number) - except (ValueError, IndexError): - pass - - print(f"Generated valid numbers: {numbers[:20]}...") # Show first 20 - - results = {} - if numbers: - counts = np.bincount(numbers, minlength=101) - valid_counts = counts[:101] # Only 0-100 - success_rate = len(numbers) / n_samples - unique_numbers = np.count_nonzero(valid_counts) - mean_val = np.mean(numbers) - std_val = np.std(numbers) - - print(f"Number of valid samples: {len(numbers)} / {n_samples} ({100*success_rate:.1f}%)") - print(f"Unique numbers generated: {unique_numbers} / 101") - print(f"Mean: {mean_val:.1f}, Std: {std_val:.1f}") - - results = { - 'valid_numbers': numbers, - 'success_rate': success_rate, - 'unique_count': unique_numbers, - 'mean': mean_val, - 'std': std_val, - 'total_valid': len(numbers) - } - else: - print("No valid numbers generated.") - results = { - 'valid_numbers': [], - 'success_rate': 0.0, - 'unique_count': 0, - 'mean': 0.0, - 'std': 0.0, - 'total_valid': 0 - } + numbers = [] + for i in range(len(final_states)): + state_tensor = final_states.tensor[i] + # Identify non-pad tokens + non_pad = state_tensor != tokenizer.pad_token_id + seq = state_tensor[non_pad] + + # Remove prompt and EOS + prompt_len = env.prompt_tokens.shape[1] + if len(seq) <= prompt_len + 1: + continue + generated_tokens = seq[prompt_len:-1] # exclude prompt and eos + decoded_text = tokenizer.decode(generated_tokens, skip_special_tokens=True) + try: + number = int(decoded_text.strip()) + if 0 <= number <= 100: # Only count valid numbers in range + numbers.append(number) + except (ValueError, IndexError): + pass + + results = {} + total_trajectories = len(trajectories) + + if numbers: + counts = np.bincount(numbers, minlength=101) + valid_counts = counts[:101] # Only 0-100 + success_rate = len(numbers) / total_trajectories + unique_numbers = np.count_nonzero(valid_counts) + mean_val = np.mean(numbers) + std_val = np.std(numbers) - return results + results = { + 'valid_numbers': numbers, + 'success_rate': success_rate, 
+ 'unique_count': unique_numbers, + 'mean': mean_val, + 'std': std_val, + 'total_valid': len(numbers), + 'total_trajectories': total_trajectories, + 'step_name': step_name + } + else: + results = { + 'valid_numbers': [], + 'success_rate': 0.0, + 'unique_count': 0, + 'mean': 0.0, + 'std': 0.0, + 'total_valid': 0, + 'total_trajectories': total_trajectories, + 'step_name': step_name + } + + return results + + +def get_lambda_lr_scheduler(optimizer, num_warmup_steps, num_training_steps): + def get_lr_mult_at_step(step): + if step <= num_warmup_steps: + return step / num_warmup_steps + return max((num_training_steps - step) / (num_training_steps - num_warmup_steps), 0) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, get_lr_mult_at_step) def main(): @@ -399,27 +368,30 @@ def main(): parser = argparse.ArgumentParser(description="Train GFlowNet to generate random numbers 0-100") parser.add_argument("--device", default="auto", help="Device to use (auto, cpu, cuda)") parser.add_argument("--model_name", default="gpt2", help="Model name from HuggingFace") - parser.add_argument("--n_steps", type=int, default=200, help="Number of training steps") - parser.add_argument("--batch_size", type=int, default=128, help="Batch size for training") - parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") + parser.add_argument("--n_steps", type=int, default=400, help="Number of training steps") + parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training") + parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate") parser.add_argument("--max_length", type=int, default=2, help="Max tokens to generate") parser.add_argument("--eval_samples", type=int, default=100, help="Number of samples for evaluation") # LoRA-specific arguments - parser.add_argument("--use_lora", action="store_true", default=False, help="Use LoRA for parameter-efficient fine-tuning") - parser.add_argument("--lora_r", type=int, default=32, help="LoRA rank") - parser.add_argument("--lora_alpha", type=int, default=16, help="LoRA alpha scaling parameter") + parser.add_argument("--use_lora", action="store_true", default=True, help="Use LoRA for parameter-efficient fine-tuning") + parser.add_argument("--lora_r", type=int, default=512, help="LoRA rank") + parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha scaling parameter") parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout probability") parser.add_argument("--target_modules", nargs="+", default=["c_attn", "c_proj"], help="Target modules for LoRA adaptation (default for GPT-2)") # Optimizer and scheduler arguments - parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay for AdamW optimizer") - parser.add_argument("--warmup_steps", type=int, default=100, help="Number of warmup steps for learning rate scheduler") + parser.add_argument("--weight_decay", type=float, default=0.00, help="Weight decay for AdamW optimizer") + parser.add_argument("--warmup_steps", type=int, default=10, help="Number of warmup steps for learning rate scheduler") parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for clipping") - parser.add_argument("--scheduler_type", default="cosine", choices=["cosine", "linear", "constant"], + parser.add_argument("--scheduler_type", default="lambda_lr", choices=["cosine", "linear", "constant", "lambda_lr"], help="Learning rate scheduler type") + # Replay buffer arguments + 
parser.add_argument("--use_buffer", action="store_true", default=True, help="Whether to use replay buffer for training stability") + args = parser.parse_args() # Device setup @@ -472,6 +444,16 @@ def main(): gflownet = TBGFlowNet(pf_module, pb_module) sampler = Sampler(pf_module) + + # Initialize replay buffer if requested + replay_buffer = None + if args.use_buffer: + replay_buffer = ReplayBuffer( + env, + capacity=args.batch_size * 4, # Use 4x batch size as capacity + prioritized_capacity=True, + prioritized_sampling=True, + ) # Set up optimizer and scheduler for trainable parameters trainable_params = [p for p in gflownet.parameters() if p.requires_grad] @@ -498,6 +480,12 @@ def main(): num_warmup_steps=args.warmup_steps, num_training_steps=args.n_steps ) + elif args.scheduler_type == "lambda_lr": + scheduler = get_lambda_lr_scheduler( + optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=args.n_steps + ) else: # constant scheduler = None @@ -507,15 +495,51 @@ def main(): print(f"Using AdamW optimizer with lr={args.lr}, weight_decay={args.weight_decay}") print(f"Learning rate scheduler: {args.scheduler_type}, warmup_steps={args.warmup_steps}") print(f"Gradient clipping: max_norm={args.max_grad_norm}") + if args.use_buffer: + print(f"Using replay buffer with capacity: {args.batch_size * 4}") + else: + print("Not using replay buffer (--use_buffer flag not set)") # Pre-training evaluation print("Evaluating model performance before training...") - pre_training_results = evaluate_model(env, sampler, tokenizer, args.eval_samples, "Pre-training Evaluation") + with torch.no_grad(): + trajectories = sampler.sample_trajectories(env, n=args.eval_samples) + pre_training_results = evaluate_model(env, trajectories, tokenizer, "Pre-training Evaluation") + print(f"Generated valid numbers: {pre_training_results['valid_numbers'][:20]}...") # Show first 20 + success_rate = pre_training_results['success_rate'] + unique_count = pre_training_results['unique_count'] + mean_val = pre_training_results['mean'] + std_val = pre_training_results['std'] + total_valid = pre_training_results['total_valid'] + print(f"Number of valid samples: {total_valid} / {len(trajectories)} ({100*success_rate:.1f}%)") + print(f"Unique numbers generated: {unique_count} / 101") + print(f"Mean: {mean_val:.1f}, Std: {std_val:.1f}") + + # Initialize loss tracking + loss_history = [] # Training Loop for step in range(args.n_steps): - trajectories = sampler.sample_trajectories(env, n=args.batch_size, save_logprobs=False) - loss = gflownet.loss(env, trajectories) + # Sample trajectories using gflownet for compatibility with replay buffer + trajectories = gflownet.sample_trajectories(env, n=args.batch_size, save_logprobs=True) + training_samples = gflownet.to_training_samples(trajectories) + + # Use replay buffer if enabled + if args.use_buffer and replay_buffer is not None: + with torch.no_grad(): + replay_buffer.add(training_samples) + # After some initial steps, use half fresh samples and half from buffer + if step > 10: + training_samples = training_samples[: args.batch_size // 2] + buffer_samples = replay_buffer.sample(n_samples=args.batch_size // 2) + training_samples.extend(buffer_samples) # type: ignore + + # Calculate loss with recalculated logprobs for buffer compatibility + recalculate_logprobs = args.use_buffer and replay_buffer is not None + loss = gflownet.loss(env, training_samples, recalculate_all_logprobs=recalculate_logprobs) + + # Store loss for plotting + loss_history.append(loss.item()) 
optimizer.zero_grad() loss.backward() @@ -532,13 +556,38 @@ def main(): if step % 1 == 0: current_lr = optimizer.param_groups[0]['lr'] - if args.max_grad_norm > 0: - print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}, grad_norm = {grad_norm:.4f}") - else: - print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}") + training_results = evaluate_model(env, trajectories, tokenizer, "Training Evaluation") + buffer_info = f", buffer_size = {len(replay_buffer) if replay_buffer else 0}" if args.use_buffer else "" + print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}, grad_norm = {grad_norm:.4f}, success_rate = {training_results['success_rate']:.1%}, unique_numbers = {training_results['unique_count']}{buffer_info}") # Post-training evaluation - post_training_results = evaluate_model(env, sampler, tokenizer, args.eval_samples, "Post-training Evaluation") + with torch.no_grad(): + trajectories = sampler.sample_trajectories(env, n=args.eval_samples) + post_training_results = evaluate_model(env, trajectories, tokenizer, "Post-training Evaluation") + + # Create and save loss plot + plt.figure(figsize=(10, 6)) + plt.plot(range(len(loss_history)), loss_history, 'b-', linewidth=1.0) + plt.title('Training Loss Over Time') + plt.xlabel('Training Step') + plt.ylabel('Loss') + plt.grid(True, alpha=0.3) + plt.tight_layout() + + # Create plots directory if it doesn't exist + os.makedirs('plots', exist_ok=True) + + # Save the plot with a descriptive filename + plot_filename = f'plots/loss_plot_steps{args.n_steps}_bs{args.batch_size}_lr{args.lr}.png' + plt.savefig(plot_filename, dpi=300, bbox_inches='tight') + print(f"\nLoss plot saved to: {plot_filename}") + + # Show basic loss statistics + print(f"Final loss: {loss_history[-1]:.4f}") + print(f"Minimum loss: {min(loss_history):.4f} (at step {loss_history.index(min(loss_history))})") + print(f"Average loss: {np.mean(loss_history):.4f}") + + plt.close() # Close the figure to free memory # Compare results print("\n--- Training Results Comparison ---") From 95e7062df49496b89ba53f35bd71a64e95a12b22 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 2 Aug 2025 17:35:36 +0200 Subject: [PATCH 08/11] fix backward mask --- tutorials/examples/train_rng_gfn.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py index 91bbeb2f..ba11c97c 100644 --- a/tutorials/examples/train_rng_gfn.py +++ b/tutorials/examples/train_rng_gfn.py @@ -113,7 +113,7 @@ def _forward_action_masks(self, states: DiscreteStates) -> torch.Tensor: seq_lens = (states.tensor != self.tokenizer.pad_token_id).sum(dim=1) # If the state is already full forbid everything but EOS. 
-        full_idx = seq_lens >= self.total_length
+        full_idx = seq_lens >= self.total_length - 1
         if full_idx.any():
             masks[full_idx] = False
             masks[full_idx, self.tokenizer.eos_token_id] = True
@@ -128,7 +128,7 @@ def _backward_action_masks(self, states: DiscreteStates) -> torch.Tensor:

         # Backward mask: only the last non-pad token can be removed (one True per row, rest False)
         batch_size = states.tensor.shape[0]
-        backward_masks = torch.zeros((batch_size, self.n_actions), dtype=torch.bool, device=states.tensor.device)
+        backward_masks = torch.zeros((batch_size, self.n_actions - 1), dtype=torch.bool, device=states.tensor.device)

         # Find which sequences can go backward (not at initial state)
         can_go_back = ~at_initial
@@ -267,7 +267,9 @@ def to_probability_distribution(
         if self.is_backward:
             # Backward masks
             masks = states.backward_masks
-            logits[~masks] = -float("inf")
+            assert self.tokenizer.eos_token_id == logits.shape[1] - 1
+            logits[:, :-1][~masks] = -float("inf")
+            logits[:, -1] = -float("inf")
         else:
             # Forward masks include all actions
             masks = states.forward_masks

From 4de9ec6820207dd5b6922f7038ed655d65601002 Mon Sep 17 00:00:00 2001
From: Omar Younis
Date: Sat, 2 Aug 2025 18:24:48 +0200
Subject: [PATCH 09/11] fixes

---
 tutorials/examples/train_rng_gfn.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py
index ba11c97c..81f88854 100644
--- a/tutorials/examples/train_rng_gfn.py
+++ b/tutorials/examples/train_rng_gfn.py
@@ -155,7 +155,7 @@ def step(self, states: DiscreteStates, actions: Actions) -> DiscreteStates:
         new_states_tensor = states.tensor.clone()
         pad_token_id = self.tokenizer.pad_token_id
         for idx in range(len(states)):
-            pad_positions = (new_states_tensor[idx] == pad_token_id).nonzero(as_tuple=False)
+            pad_positions = (new_states_tensor[idx] == pad_token_id).nonzero()
             if pad_positions.numel() == 0:
                 continue  # already full, should not happen thanks to masks
             first_pad = pad_positions[0].item()
@@ -169,7 +169,7 @@ def backward_step(self, states: DiscreteStates, actions: Actions) -> DiscreteSta
         pad_token_id = self.tokenizer.pad_token_id
         new_states_tensor = states.tensor.clone()
         for idx in range(len(states)):
-            non_pad_positions = (new_states_tensor[idx] != pad_token_id).nonzero(as_tuple=False)
+            non_pad_positions = (new_states_tensor[idx] != pad_token_id).nonzero()
             if non_pad_positions.numel() == 0:
                 continue
             last_idx = non_pad_positions[-1].item()
@@ -240,9 +240,7 @@ def expected_output_dim(self):
     def forward(self, states: DiscreteStates):  # type: ignore[override]
         input_ids = states.tensor
-        # Build an attention mask: 1 for non-pad.
         attention_mask = (input_ids != self.tokenizer.pad_token_id).long()
-
         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)

         # Select logits corresponding to the last *non-pad* token.
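For readers following the forward pass, the context line above refers to a gather over the time dimension. A self-contained sketch of that selection, assuming right-padding (which this environment guarantees, since tokens are always appended at the first pad position) and illustrative tensor names:

    import torch

    def last_non_pad_logits(logits: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # logits: (batch, seq_len, vocab); attention_mask: (batch, seq_len), 1 for non-pad.
        seq_lengths = attention_mask.sum(dim=1)          # number of non-pad tokens per row
        rows = torch.arange(logits.shape[0], device=logits.device)
        return logits[rows, seq_lengths - 1, :]          # (batch, vocab) at last non-pad position

Because the prompt is always present, seq_lengths is at least 1, so the index seq_lengths - 1 is always valid.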
@@ -326,14 +324,16 @@ def evaluate_model(env, trajectories, tokenizer, step_name="Evaluation"): counts = np.bincount(numbers, minlength=101) valid_counts = counts[:101] # Only 0-100 success_rate = len(numbers) / total_trajectories - unique_numbers = np.count_nonzero(valid_counts) + unique_numbers = np.arange(101)[valid_counts > 0] + unique_count = len(unique_numbers) mean_val = np.mean(numbers) std_val = np.std(numbers) results = { 'valid_numbers': numbers, 'success_rate': success_rate, - 'unique_count': unique_numbers, + 'unique_count': unique_count, + 'unique_numbers': unique_numbers, 'mean': mean_val, 'std': std_val, 'total_valid': len(numbers), @@ -371,9 +371,9 @@ def main(): parser.add_argument("--device", default="auto", help="Device to use (auto, cpu, cuda)") parser.add_argument("--model_name", default="gpt2", help="Model name from HuggingFace") parser.add_argument("--n_steps", type=int, default=400, help="Number of training steps") - parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training") + parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training") parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate") - parser.add_argument("--max_length", type=int, default=2, help="Max tokens to generate") + parser.add_argument("--max_length", type=int, default=3, help="Max tokens to generate") parser.add_argument("--eval_samples", type=int, default=100, help="Number of samples for evaluation") # LoRA-specific arguments @@ -510,11 +510,13 @@ def main(): print(f"Generated valid numbers: {pre_training_results['valid_numbers'][:20]}...") # Show first 20 success_rate = pre_training_results['success_rate'] unique_count = pre_training_results['unique_count'] + unique_numbers = pre_training_results['unique_numbers'] mean_val = pre_training_results['mean'] std_val = pre_training_results['std'] total_valid = pre_training_results['total_valid'] print(f"Number of valid samples: {total_valid} / {len(trajectories)} ({100*success_rate:.1f}%)") print(f"Unique numbers generated: {unique_count} / 101") + print(f"Unique numbers: {unique_numbers}") print(f"Mean: {mean_val:.1f}, Std: {std_val:.1f}") # Initialize loss tracking @@ -560,7 +562,7 @@ def main(): current_lr = optimizer.param_groups[0]['lr'] training_results = evaluate_model(env, trajectories, tokenizer, "Training Evaluation") buffer_info = f", buffer_size = {len(replay_buffer) if replay_buffer else 0}" if args.use_buffer else "" - print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}, grad_norm = {grad_norm:.4f}, success_rate = {training_results['success_rate']:.1%}, unique_numbers = {training_results['unique_count']}{buffer_info}") + print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}, grad_norm = {grad_norm:.4f}, success_rate = {training_results['success_rate']:.1%}, unique_numbers = {training_results['unique_count']}, unique_numbers = {training_results['unique_numbers']}{buffer_info}") # Post-training evaluation with torch.no_grad(): From a3a5f69fdc27d198b3fd9d50503fc61e189a5774 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 2 Aug 2025 19:06:13 +0200 Subject: [PATCH 10/11] add pre-filling --- tutorials/examples/train_rng_gfn.py | 151 +++++++++++++++++++++++++++- 1 file changed, 147 insertions(+), 4 deletions(-) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py index 81f88854..b986863a 100644 --- a/tutorials/examples/train_rng_gfn.py +++ b/tutorials/examples/train_rng_gfn.py @@ -47,6 +47,7 
@@ from gfn.gflownet.trajectory_balance import TBGFlowNet from gfn.samplers import Sampler from gfn.containers import ReplayBuffer +from gfn.containers.trajectories import Trajectories class RNGEnv(DiscreteEnv): @@ -218,6 +219,142 @@ def log_reward(self, states: DiscreteStates) -> torch.Tensor: return rewards +class ExpertReplayBuffer(ReplayBuffer): + """Custom replay buffer that is prefilled with expert trajectories and has no-op add operation.""" + + def __init__(self, env, expert_trajectories, capacity=1000, prioritized_capacity=True, prioritized_sampling=True): + """Initialize with expert trajectories. + + Args: + env: The environment + expert_trajectories: Trajectories object containing expert demonstrations + capacity: Buffer capacity (ignored since we use fixed expert data) + prioritized_capacity: Whether to use prioritized capacity + prioritized_sampling: Whether to use prioritized sampling + """ + super().__init__(env, capacity, prioritized_capacity, prioritized_sampling) + + # Initialize the buffer with expert trajectories + self.training_objects = expert_trajectories + print(f"ExpertReplayBuffer initialized with {len(expert_trajectories)} expert trajectories") + + def add(self, training_objects): + """No-op add operation - keeps the expert trajectories unchanged.""" + # Do nothing - we want to keep only the expert trajectories + pass + + def __len__(self): + """Return the number of expert trajectories.""" + return len(self.training_objects) if self.training_objects is not None else 0 + + +def generate_expert_trajectories(env, tokenizer): + """Generate expert trajectories for numbers 0-100.""" + from gfn.samplers import Sampler + + # We'll create a simple expert policy and use it to generate trajectories + class ExpertPolicy: + def __init__(self, env, tokenizer, target_numbers): + self.env = env + self.tokenizer = tokenizer + self.target_numbers = target_numbers + self.current_targets = None + + def __call__(self, states, *args, **kwargs): + # Return logits that strongly favor the correct next token + batch_size = states.batch_shape[0] + logits = torch.full((batch_size, self.tokenizer.vocab_size), -10.0, device=states.device) + + for i in range(batch_size): + if self.current_targets is not None and i < len(self.current_targets): + target_num = self.current_targets[i] + # Check current state to see what token should come next + state_tensor = states.tensor[i] + prompt_len = self.env.prompt_tokens.shape[1] + + # Get non-pad positions after prompt + non_pad_mask = state_tensor[prompt_len:] != self.tokenizer.pad_token_id + num_generated = non_pad_mask.sum().item() + + # Get target tokens for this number + target_str = str(target_num) + target_tokens = self.tokenizer.encode(target_str, add_special_tokens=False) + + if num_generated < len(target_tokens): + # Next token should be the next digit + next_token = target_tokens[num_generated] + logits[i, next_token] = 10.0 # Strong preference + else: + # Should terminate with EOS + logits[i, self.tokenizer.eos_token_id] = 10.0 + + else: + # Default to EOS if no target specified + logits[i, self.tokenizer.eos_token_id] = 10.0 + + return logits + + def to_probability_distribution(self, states, logits, **kwargs): + # Apply forward mask and return categorical distribution + masks = states.forward_masks + logits_masked = logits.clone() + logits_masked[~masks] = -float("inf") + + # Create a custom distribution that samples with the right shape + categorical = torch.distributions.Categorical(logits=logits_masked) + class ShapedCategorical: + def 
__init__(self, cat_dist): + self.cat_dist = cat_dist + + def sample(self): + # Sample from categorical and reshape to (batch_size, 1) + samples = self.cat_dist.sample() + return samples.unsqueeze(-1) + + def log_prob(self, value): + # value should have shape (batch_size, 1), squeeze for categorical + if value.dim() > 1: + value = value.squeeze(-1) + return self.cat_dist.log_prob(value) + + return ShapedCategorical(categorical) + + # Use batched generation for efficiency + expert_trajectories = Trajectories(env) + batch_size = 32 # Process numbers in batches + + for start_num in range(0, 101, batch_size): + end_num = min(start_num + batch_size, 101) + current_batch_size = end_num - start_num + target_numbers = list(range(start_num, end_num)) + + # Create expert policy for this batch + expert_policy = ExpertPolicy(env, tokenizer, target_numbers) + expert_policy.current_targets = target_numbers + + # Create a dummy module that uses our expert policy + class ExpertModule: + def __init__(self, expert_policy): + self.expert_policy = expert_policy + self.is_backward = False + + def __call__(self, states, *args, **kwargs): + return self.expert_policy(states, *args, **kwargs) + + def to_probability_distribution(self, states, logits, **kwargs): + return self.expert_policy.to_probability_distribution(states, logits, **kwargs) + + expert_module = ExpertModule(expert_policy) + expert_sampler = Sampler(expert_module) + + # Sample trajectories for this batch + batch_trajectories = expert_sampler.sample_trajectories(env, n=current_batch_size) + expert_trajectories.extend(batch_trajectories) + + print(f"Generated {len(expert_trajectories)} expert trajectories") + return expert_trajectories + + class LLMGFNModule(DiscretePolicyEstimator): """GFNModule wrapping a pretrained LLM to act as a policy.""" @@ -450,9 +587,15 @@ def main(): # Initialize replay buffer if requested replay_buffer = None if args.use_buffer: - replay_buffer = ReplayBuffer( + # Generate expert trajectories for numbers 0-100 + print("Generating expert trajectories for numbers 0-100...") + expert_trajectories = generate_expert_trajectories(env, tokenizer) + + # Use custom replay buffer with expert trajectories + replay_buffer = ExpertReplayBuffer( env, - capacity=args.batch_size * 4, # Use 4x batch size as capacity + expert_trajectories, + capacity=args.batch_size * 4, # Use 4x batch size as capacity (ignored) prioritized_capacity=True, prioritized_sampling=True, ) @@ -498,7 +641,7 @@ def main(): print(f"Learning rate scheduler: {args.scheduler_type}, warmup_steps={args.warmup_steps}") print(f"Gradient clipping: max_norm={args.max_grad_norm}") if args.use_buffer: - print(f"Using replay buffer with capacity: {args.batch_size * 4}") + print(f"Using expert replay buffer with {len(replay_buffer)} expert trajectories") else: print("Not using replay buffer (--use_buffer flag not set)") @@ -562,7 +705,7 @@ def main(): current_lr = optimizer.param_groups[0]['lr'] training_results = evaluate_model(env, trajectories, tokenizer, "Training Evaluation") buffer_info = f", buffer_size = {len(replay_buffer) if replay_buffer else 0}" if args.use_buffer else "" - print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}, grad_norm = {grad_norm:.4f}, success_rate = {training_results['success_rate']:.1%}, unique_numbers = {training_results['unique_count']}, unique_numbers = {training_results['unique_numbers']}{buffer_info}") + print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}, grad_norm = {grad_norm:.4f}, success_rate = 
{training_results['success_rate']:.1%}, unique_count = {training_results['unique_count']}{buffer_info}") # Post-training evaluation with torch.no_grad(): From fb0f1fb76d92697b060af3ecd3144579ac82add9 Mon Sep 17 00:00:00 2001 From: younik Date: Sat, 13 Sep 2025 16:23:16 -0400 Subject: [PATCH 11/11] gpt-j + subtb --- tutorials/examples/train_rng_gfn.py | 97 +++++++++++++++++++---------- 1 file changed, 64 insertions(+), 33 deletions(-) diff --git a/tutorials/examples/train_rng_gfn.py b/tutorials/examples/train_rng_gfn.py index b986863a..cbf55da1 100644 --- a/tutorials/examples/train_rng_gfn.py +++ b/tutorials/examples/train_rng_gfn.py @@ -45,9 +45,12 @@ from gfn.modules import DiscretePolicyEstimator, GFNModule from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.gflownet.trajectory_balance import TBGFlowNet +from gfn.gflownet.sub_trajectory_balance import SubTBGFlowNet from gfn.samplers import Sampler from gfn.containers import ReplayBuffer from gfn.containers.trajectories import Trajectories +from gfn.utils.modules import MLP +from gfn.modules import ScalarEstimator class RNGEnv(DiscreteEnv): @@ -240,7 +243,6 @@ def __init__(self, env, expert_trajectories, capacity=1000, prioritized_capacity def add(self, training_objects): """No-op add operation - keeps the expert trajectories unchanged.""" - # Do nothing - we want to keep only the expert trajectories pass def __len__(self): @@ -379,6 +381,7 @@ def forward(self, states: DiscreteStates): # type: ignore[override] input_ids = states.tensor attention_mask = (input_ids != self.tokenizer.pad_token_id).long() outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) + outputs.logits = outputs.logits[..., :self.tokenizer.vocab_size] # remove extra tokens # Select logits corresponding to the last *non-pad* token. 
seq_lengths = attention_mask.sum(dim=1) # (batch,) @@ -482,6 +485,7 @@ def evaluate_model(env, trajectories, tokenizer, step_name="Evaluation"): 'valid_numbers': [], 'success_rate': 0.0, 'unique_count': 0, + 'unique_numbers': [], 'mean': 0.0, 'std': 0.0, 'total_valid': 0, @@ -506,9 +510,10 @@ def main(): parser = argparse.ArgumentParser(description="Train GFlowNet to generate random numbers 0-100") parser.add_argument("--device", default="auto", help="Device to use (auto, cpu, cuda)") - parser.add_argument("--model_name", default="gpt2", help="Model name from HuggingFace") + parser.add_argument("--model_name", default="EleutherAI/gpt-j-6B", help="Model name from HuggingFace") parser.add_argument("--n_steps", type=int, default=400, help="Number of training steps") - parser.add_argument("--batch_size", type=int, default=64, help="Batch size for training") + parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training") + parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Number of gradient accumulation steps") parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate") parser.add_argument("--max_length", type=int, default=3, help="Max tokens to generate") parser.add_argument("--eval_samples", type=int, default=100, help="Number of samples for evaluation") @@ -519,7 +524,7 @@ def main(): parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha scaling parameter") parser.add_argument("--lora_dropout", type=float, default=0.1, help="LoRA dropout probability") parser.add_argument("--target_modules", nargs="+", default=["c_attn", "c_proj"], - help="Target modules for LoRA adaptation (default for GPT-2)") + help="Target modules for LoRA adaptation (default for GPT-J)") # Optimizer and scheduler arguments parser.add_argument("--weight_decay", type=float, default=0.00, help="Weight decay for AdamW optimizer") @@ -538,6 +543,7 @@ def main(): device = 'cuda' if torch.cuda.is_available() else 'cpu' else: device = args.device + torch.set_default_device(device) print(f"Using device: {device}") # Model and tokenizer setup @@ -547,7 +553,7 @@ def main(): tokenizer.pad_token = tokenizer.eos_token # Load base model - model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device) + model = AutoModelForCausalLM.from_pretrained(args.model_name, use_safetensors=True).to(device) if args.use_lora: # Configure LoRA @@ -556,7 +562,7 @@ def main(): r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, - target_modules=args.target_modules, + #target_modules=args.target_modules, bias="none", ) @@ -580,8 +586,13 @@ def main(): state_dim = env.total_length pf_module = LLMGFNModule(model, tokenizer, state_dim, is_backward=False) pb_module = LLMGFNModule(model, tokenizer, state_dim, is_backward=True) + module_logF = MLP( + input_dim=state_dim, + output_dim=1, # Important for ScalarEstimators! 
+ ) + logF_module = ScalarEstimator(module_logF) - gflownet = TBGFlowNet(pf_module, pb_module) + gflownet = SubTBGFlowNet(pf_module, pb_module, logF_module) #TBGFlowNet(pf_module, pb_module) sampler = Sampler(pf_module) # Initialize replay buffer if requested @@ -595,7 +606,6 @@ def main(): replay_buffer = ExpertReplayBuffer( env, expert_trajectories, - capacity=args.batch_size * 4, # Use 4x batch size as capacity (ignored) prioritized_capacity=True, prioritized_sampling=True, ) @@ -635,7 +645,10 @@ def main(): scheduler = None param_count = sum(p.numel() for p in trainable_params) - print(f"Training for {args.n_steps} steps with batch size {args.batch_size}") + effective_batch_size = args.batch_size * args.gradient_accumulation_steps + print(f"Training for {args.n_steps} steps with batch size {args.batch_size} (effective batch size: {effective_batch_size})") + if args.gradient_accumulation_steps > 1: + print(f"Using gradient accumulation: {args.gradient_accumulation_steps} steps per optimizer update") print(f"Optimizing {param_count:,} trainable parameters across {len(trainable_params)} parameter groups") print(f"Using AdamW optimizer with lr={args.lr}, weight_decay={args.weight_decay}") print(f"Learning rate scheduler: {args.scheduler_type}, warmup_steps={args.warmup_steps}") @@ -665,37 +678,52 @@ def main(): # Initialize loss tracking loss_history = [] - # Training Loop + # Training Loop with Gradient Accumulation + optimizer.zero_grad() + accumulated_loss = 0.0 + grad_norm = 0.0 + for step in range(args.n_steps): - # Sample trajectories using gflownet for compatibility with replay buffer - trajectories = gflownet.sample_trajectories(env, n=args.batch_size, save_logprobs=True) - training_samples = gflownet.to_training_samples(trajectories) - - # Use replay buffer if enabled - if args.use_buffer and replay_buffer is not None: - with torch.no_grad(): - replay_buffer.add(training_samples) - # After some initial steps, use half fresh samples and half from buffer - if step > 10: - training_samples = training_samples[: args.batch_size // 2] - buffer_samples = replay_buffer.sample(n_samples=args.batch_size // 2) - training_samples.extend(buffer_samples) # type: ignore + step_loss = 0.0 - # Calculate loss with recalculated logprobs for buffer compatibility - recalculate_logprobs = args.use_buffer and replay_buffer is not None - loss = gflownet.loss(env, training_samples, recalculate_all_logprobs=recalculate_logprobs) + # Accumulate gradients over multiple mini-batches + for accum_step in range(args.gradient_accumulation_steps): + # Sample trajectories using gflownet for compatibility with replay buffer + trajectories = gflownet.sample_trajectories(env, n=args.batch_size, save_logprobs=True) + training_samples = gflownet.to_training_samples(trajectories) + + # Use replay buffer if enabled + if args.use_buffer and replay_buffer is not None: + with torch.no_grad(): + replay_buffer.add(training_samples) + # After some initial steps, use half fresh samples and half from buffer + if step > 10: + training_samples = training_samples[: args.batch_size // 2] + buffer_samples = replay_buffer.sample(n_samples=args.batch_size // 2) + training_samples.extend(buffer_samples) # type: ignore + + # Calculate loss with recalculated logprobs for buffer compatibility + recalculate_logprobs = args.use_buffer and replay_buffer is not None + loss = gflownet.loss(env, training_samples, recalculate_all_logprobs=recalculate_logprobs) + + # Scale loss by accumulation steps to maintain same effective learning rate + loss = 
loss / args.gradient_accumulation_steps
+                step_loss += loss.item()
+
+                # Backward pass (gradients accumulate)
+                loss.backward()

-        # Store loss for plotting
-        loss_history.append(loss.item())
-
-        optimizer.zero_grad()
-        loss.backward()
+        # Store total loss for this step for plotting
+        loss_history.append(step_loss)
+        accumulated_loss += step_loss

         # Gradient clipping for stability
         if args.max_grad_norm > 0:
             grad_norm = torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)

+        # Optimizer step after accumulating gradients
         optimizer.step()
+        optimizer.zero_grad()

         # Update learning rate scheduler
         if scheduler is not None:
@@ -704,8 +732,10 @@ def main():
         if step % 1 == 0:
             current_lr = optimizer.param_groups[0]['lr']
             training_results = evaluate_model(env, trajectories, tokenizer, "Training Evaluation")
+            effective_batch_size = args.batch_size * args.gradient_accumulation_steps
             buffer_info = f", buffer_size = {len(replay_buffer) if replay_buffer else 0}" if args.use_buffer else ""
-            print(f"Step {step:4d}: loss = {loss.item():.4f}, lr = {current_lr:.6f}, grad_norm = {grad_norm:.4f}, success_rate = {training_results['success_rate']:.1%}, unique_count = {training_results['unique_count']}{buffer_info}")
+            accum_info = f", grad_accum = {args.gradient_accumulation_steps}, eff_bs = {effective_batch_size}" if args.gradient_accumulation_steps > 1 else ""
+            print(f"Step {step:4d}: loss = {step_loss:.4f}, lr = {current_lr:.6f}, grad_norm = {grad_norm:.4f}, success_rate = {training_results['success_rate']:.1%}, unique_count = {training_results['unique_count']}{buffer_info}{accum_info}")

     # Post-training evaluation
     with torch.no_grad():
@@ -725,7 +755,8 @@ def main():
     os.makedirs('plots', exist_ok=True)

     # Save the plot with a descriptive filename
-    plot_filename = f'plots/loss_plot_steps{args.n_steps}_bs{args.batch_size}_lr{args.lr}.png'
+    effective_bs = args.batch_size * args.gradient_accumulation_steps
+    plot_filename = f'plots/loss_plot_steps{args.n_steps}_bs{args.batch_size}_effbs{effective_bs}_lr{args.lr}.png'
     plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
     print(f"\nLoss plot saved to: {plot_filename}")
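To close, a standalone sketch of the gradient-accumulation pattern the final patch adopts: each micro-batch loss is divided by the number of accumulation steps so the summed gradient matches one large batch, and clipping plus the optimizer step happen once per outer step. The toy model and random data here are hypothetical stand-ins for the GFlowNet loss over sampled trajectories:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
    accum_steps, n_steps, max_grad_norm = 8, 10, 1.0

    optimizer.zero_grad()
    for step in range(n_steps):
        step_loss = 0.0
        for _ in range(accum_steps):
            x = torch.randn(8, 4)
            loss = model(x).pow(2).mean() / accum_steps  # scale so grads match one large batch
            loss.backward()                              # gradients accumulate across micro-batches
            step_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        print(f"step {step}: loss = {step_loss:.4f}")

Dividing inside the inner loop rather than scaling the learning rate keeps the clipped gradient norm comparable across different accumulation settings, which is why the script reports an effective batch size of batch_size * gradient_accumulation_steps.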