"""
On-Policy Distillation (OPD) training on GSM8K using vLLM.

This example demonstrates on-policy distillation where:
  - A student model samples trajectories
  - A teacher model provides per-token logprobs as dense supervision
  - Training minimizes reverse KL divergence: KL(student || teacher)

This combines the benefits of:
  - On-policy learning (student samples from itself)
  - Dense supervision (per-token feedback, not sparse rewards)

The key insight: teacher logprobs are an *intrinsic scorer* attached to the Agent.
The scorer runs during Agent.act() and scores flow through to training.

Reference: https://thinkingmachines.ai/blog/on-policy-distillation

Usage:
    python train_opd_gsm8k.py \
        --student-model Qwen/Qwen3-8B-Base \
        --teacher-model Qwen/Qwen3-32B \
        --limit 1000

Requirements:
    - vLLM servers running for student (port 8000) and teacher (port 8001)
"""

from __future__ import annotations

import argparse
import queue
from typing import List, Dict, Any

import torch
from datasets import load_dataset  # type: ignore
from transformers import AutoModelForCausalLM, AutoTokenizer

from ludic.agent import Agent
from ludic.context import FullDialog
from ludic.inference import InferenceSpec, SamplingParams, ReturnSpec, HFChatTemplate
from ludic.interaction import SingleAgentSyncProtocol
from ludic.parsers import boxed_parser
from ludic.training import (
    RolloutEngine,
    RolloutBatchSource,
    Trainer,
    TrainerConfig,
    make_dataset_queue_requests_fn,
    make_opd,
    RequestsExhausted,
)
from ludic.training import Reducer, PrintLogger, default_reducers
from ludic.training.scoring import make_vllm_teacher_scorer

# Try to import environments
try:
    from environments.gsm8k import GSM8KEnv
except ImportError:
    # Fallback: define a minimal GSM8K env
    from ludic.envs import DatasetQAEnv

    class GSM8KEnv(DatasetQAEnv):
        """Minimal GSM8K QA environment built on DatasetQAEnv."""

        def __init__(self, sample: Dict[str, Any], system_prompt: str = ""):
            super().__init__(
                question=sample["question"],
                ground_truth=self._extract_answer(sample["answer"]),
                system_prompt=system_prompt or "Solve the following problem step by step. Put your final answer in \\boxed{}.",
            )

        @staticmethod
        def _extract_answer(answer_text: str) -> str:
            # GSM8K answers have format "...\n#### answer"
            if "####" in answer_text:
                return answer_text.split("####")[-1].strip()
            return answer_text.strip()


def load_gsm8k(split: str, limit: int | None) -> List[Dict[str, Any]]:
    """Load GSM8K dataset samples.

    Args:
        split: Dataset split name (e.g. "train").
        limit: Stop after this many samples (None = use all).

    Returns:
        List of dicts with "question", "answer", and "id" keys.
    """
    ds = load_dataset("gsm8k", "main", split=split)
    samples: List[Dict[str, Any]] = []
    for idx, row in enumerate(ds):
        samples.append(
            {
                "question": row["question"],
                "answer": row["answer"],
                "id": row.get("id", idx),
            }
        )
        if limit is not None and len(samples) >= limit:
            break
    return samples


def main():
    parser = argparse.ArgumentParser(description="OPD training on GSM8K")

    # Model configuration
    parser.add_argument("--student-model", default="Qwen/Qwen3-8B-Base",
                        help="Student model name/path")
    parser.add_argument("--teacher-model", default="Qwen/Qwen3-32B",
                        help="Teacher model name/path")

    # vLLM server configuration
    parser.add_argument("--student-host", default="127.0.0.1")
    parser.add_argument("--student-port", type=int, default=8000)
    parser.add_argument("--teacher-host", default="127.0.0.1")
    parser.add_argument("--teacher-port", type=int, default=8001)

    # Data configuration
    parser.add_argument("--split", default="train")
    parser.add_argument("--limit", type=int, default=None,
                        help="Limit training samples (None = use all)")

    # Training configuration
    parser.add_argument("--rollouts-per-update", type=int, default=64,
                        help="Number of rollouts per training step")
    parser.add_argument("--train-steps", type=int, default=100,
                        help="Number of training steps")
    parser.add_argument("--max-seq-len", type=int, default=2048,
                        help="Max sequence length")
    parser.add_argument("--micro-token-budget", type=int, default=32768,
                        help="Max padded tokens per micro-batch")
    parser.add_argument("--max-completion-tokens", type=int, default=1024,
                        help="Max completion tokens per rollout")
    parser.add_argument("--temperature", type=float, default=1.0,
                        help="Sampling temperature")
    parser.add_argument("--concurrency", type=int, default=32,
                        help="Rollout concurrency")

    # OPD-specific configuration
    parser.add_argument("--kl-coeff", type=float, default=1.0,
                        help="Coefficient for reverse KL loss")
    parser.add_argument("--length-normalize", action="store_true",
                        help="Normalize loss by sequence length")

    # System prompt
    parser.add_argument("--system-prompt", type=str,
                        default="First, think step by step. Then put your final answer inside \\boxed{...}.")

    args = parser.parse_args()

    # Load training data
    print(f"Loading GSM8K {args.split} split...")
    train_samples = load_gsm8k(args.split, args.limit)
    if not train_samples:
        raise SystemExit("No GSM8K samples loaded.")
    print(f"Loaded {len(train_samples)} training samples")

    # Create sample queue
    samples_q: queue.Queue = queue.Queue()
    for idx, s in enumerate(train_samples):
        samples_q.put((idx, s))

    # Load tokenizer and model
    print(f"Loading student model: {args.student_model}")
    tokenizer = AutoTokenizer.from_pretrained(args.student_model)
    model = AutoModelForCausalLM.from_pretrained(
        args.student_model,
        torch_dtype=torch.bfloat16,
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Create vLLM clients
    from ludic.inference import VLLMChatClient
    from ludic.distributed.adapters import create_vllm_publisher

    # Student client for sampling
    client = VLLMChatClient(
        host=args.student_host,
        port=args.student_port,
        enable_weight_updates=True,
    )
    publisher = create_vllm_publisher(client)

    # Teacher scorer - computes per-token logprobs during Agent.act()
    teacher_scorer = make_vllm_teacher_scorer(
        base_url=f"http://{args.teacher_host}:{args.teacher_port}",
        model=args.teacher_model,
    )

    chat_template = HFChatTemplate(tokenizer)

    # Environment and protocol registries
    env_registry = {
        "gsm8k": lambda sample: GSM8KEnv(
            sample=sample,
            system_prompt=args.system_prompt,
        )
    }

    # Agent with teacher scorer - scores flow through to training
    def protocol_factory():
        return SingleAgentSyncProtocol(
            agent=Agent(
                client=client,
                model=args.student_model,
                ctx=FullDialog(),
                parser=boxed_parser,
                chat_template=chat_template,
                scorers=[teacher_scorer],  # OPD: teacher provides per-token scores
            )
        )

    protocol_registry = {"single_agent": protocol_factory}

    # Create OPD algorithm
    algo = make_opd(
        kl_coeff=args.kl_coeff,
        length_normalize=args.length_normalize,
        name="opd",
    )

    # Create rollout engine
    engine = RolloutEngine(
        env_registry=env_registry,
        protocol_registry=protocol_registry,
        jsonl_path="opd_rollouts.jsonl",
    )

    # Create inference spec
    train_inference = InferenceSpec(
        sampling=SamplingParams(
            temperature=args.temperature,
            max_tokens=args.max_completion_tokens,
        ),
        return_=ReturnSpec.for_rl(top_logprobs_k=1),
    )

    # Create requests function
    requests_fn = make_dataset_queue_requests_fn(
        samples_q,
        batch_size=args.rollouts_per_update,
        env_kind="gsm8k",
        protocol_kind="single_agent",
        inference=train_inference,
        protocol_kwargs={},
        request_meta_fn=lambda idx, sample: {
            "sample_index": idx,
            "question_id": sample.get("id", idx),
        },
        env_seed_fn=lambda idx, _sample: idx,
        sampling_seed_fn=lambda idx, _sample: idx,
    )

    # Create batch source (no teacher_client needed - it's in the Agent!)
    batch_source = RolloutBatchSource(
        orchestrator=engine,
        credit_assigner=algo.credit_assigner,
        requests_fn=requests_fn,
        max_steps=1,
        concurrency=args.concurrency,
    )

    # BUGFIX: pass the pad token *id* (an int), not the tokenizer object.
    # Some base models ship without a pad token; fall back to EOS, which is
    # the conventional choice for causal LM padding.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    # Trainer config
    cfg = TrainerConfig(
        model_device=device,
        max_seq_len=args.max_seq_len,
        micro_token_budget=args.micro_token_budget,
        max_grad_norm=0.5,
        pad_token_id=pad_id,
    )

    # Reducers for logging
    reducers = {
        **default_reducers(),
        "correct_rate": Reducer(
            kind="count_true",
            source="correct",
            normalize_by="rollouts",
        ),
    }

    # Logger
    logger_keys = [
        "train/loss",
        "train/reverse_kl_mean",
        "train/avg_total_reward",
        "train/correct_rate",
        "train/avg_completion_length",
        "train/num_samples",
    ]
    train_logger = PrintLogger(prefix="[opd]", keys=logger_keys, precision=4)

    # Create trainer
    trainer = Trainer(
        model=model,
        algo=algo,
        batch_source=batch_source,
        publisher=publisher,
        enable_gradient_checkpointing=True,
        cfg=cfg,
        train_logger=train_logger,
        reducers=reducers,
    )

    # Train
    print(f"\nStarting OPD training for {args.train_steps} steps...")
    print(f"  Student: {args.student_model}")
    print(f"  Teacher: {args.teacher_model}")
    print(f"  KL coefficient: {args.kl_coeff}")
    print()

    try:
        trainer.train_sync(args.train_steps)
    except RequestsExhausted:
        print("No more training samples; stopping.")

    print("\nTraining complete!")


if __name__ == "__main__":
    main()
ludic.inference.chat_template import ChatTemplate + from ludic.training.scoring import IntrinsicScorer, TokenLevelScorer, ActionLevelScorer _DEFAULT_INCOMPLETE_FEEDBACK = ( "Your response was cut off because it exceeded the token limit. " @@ -65,6 +66,7 @@ class AgentActStep: loop_index: int tool_calls: Optional[List[Dict[str, Any]]] = None tool_results: Optional[List[Dict[str, Any]]] = None + intrinsic_scores: Dict[str, Union[List[float], float]] = field(default_factory=dict) @dataclass @@ -97,6 +99,7 @@ def __init__( incomplete_completion_penalty: float = -0.1, incomplete_completion_feedback: str = _DEFAULT_INCOMPLETE_FEEDBACK, chat_template: Optional["ChatTemplate"] = None, + scorers: Optional[List["IntrinsicScorer"]] = None, ) -> None: """ Initializes the Agent. @@ -113,6 +116,9 @@ def __init__( its completion is cut off. chat_template: ChatTemplate for token-in mode. If None, the agent will try to build an HFChatTemplate from client.tokenizer. + scorers: Optional list of intrinsic scorers (TokenLevelScorer or + ActionLevelScorer) to run after each action. Scores are attached + to AgentActStep.intrinsic_scores. """ self._client = client self._model = model @@ -132,6 +138,7 @@ def __init__( chat_template = HFChatTemplate(tokenizer) self._chat_template = chat_template + self._scorers: List["IntrinsicScorer"] = scorers or [] self.last_info: Dict[str, Any] = {} async def _infer_once( @@ -252,6 +259,25 @@ async def act( # 5. Parse (format the raw text action) parse_result = self._parser(raw_action) + # 6. 
Run intrinsic scorers + intrinsic_scores: Dict[str, Union[List[float], float]] = {} + if self._scorers: + from ludic.training.scoring import TokenLevelScorer, ActionLevelScorer + + prompt_text = self._chat_template.apply( + messages, add_generation_prompt=True + ).prompt_text + + for scorer in self._scorers: + if isinstance(scorer, TokenLevelScorer): + scores = await scorer.score_tokens( + list(token_trace.completion_token_ids) + ) + intrinsic_scores[scorer.name] = scores + elif isinstance(scorer, ActionLevelScorer): + score = await scorer.score_action(prompt_text, raw_action) + intrinsic_scores[scorer.name] = score + step = AgentActStep( prompt_messages=messages, action=raw_action, @@ -260,6 +286,7 @@ async def act( trace=token_trace, action_target="env", loop_index=0, + intrinsic_scores=intrinsic_scores, ) return AgentActResult(steps=[step]) diff --git a/src/ludic/interaction/multi_agent.py b/src/ludic/interaction/multi_agent.py index 788edea..51af49b 100644 --- a/src/ludic/interaction/multi_agent.py +++ b/src/ludic/interaction/multi_agent.py @@ -172,6 +172,7 @@ async def run( turn_id=turn_ids[agent_id], tool_calls=act_step.tool_calls, tool_results=act_step.tool_results, + intrinsic_scores=act_step.intrinsic_scores, ) collector.add(agent_id, agent_step) step_indices[agent_id] += 1 diff --git a/src/ludic/interaction/single_agent.py b/src/ludic/interaction/single_agent.py index af6543d..76b65d9 100644 --- a/src/ludic/interaction/single_agent.py +++ b/src/ludic/interaction/single_agent.py @@ -179,6 +179,7 @@ async def run( turn_id=turn_id, tool_calls=act_step.tool_calls, tool_results=act_step.tool_results, + intrinsic_scores=act_step.intrinsic_scores, ) steps.append(agent_step) turn_agent_step_ids.append(agent_step.id) @@ -255,6 +256,7 @@ async def run( turn_id=turn_id, tool_calls=act_step.tool_calls, tool_results=act_step.tool_results, + intrinsic_scores=act_step.intrinsic_scores, ) steps.append(agent_step) turn_agent_step_ids.append(agent_step.id) diff --git 
def make_opd(
    *,
    kl_coeff: float = 1.0,
    length_normalize: bool = False,
    name: str = "opd",
) -> RLAlgorithm:
    """
    On-Policy Distillation (OPD).

    Dense per-token supervision from a teacher model. The student samples
    trajectories, teacher provides logprobs, and training minimizes reverse KL.

    Core idea: Sample from student, grade each token with teacher logprobs.
    This combines on-policy learning (samples from student) with dense
    supervision (per-token teacher signal), achieving better compute
    efficiency than sparse RL rewards.

    Credit assignment: ConstantCredit(1.0) - the per-step "advantage" is
    implicit in the loss (negative reverse KL per token).

    Loss: ReverseKLLoss - minimizes KL(student || teacher) per token.

    Prerequisites:
        - Rollouts must carry teacher logprobs in
          SAWItem.attachments.teacher_logps. The supported way to produce
          them is to attach a TokenLevelScorer named "teacher_logps" to the
          Agent (e.g. ludic.training.scoring.make_vllm_teacher_scorer);
          the scorer runs during Agent.act() and the scores flow through
          the rollout engine into the attachments.
        - The collator extracts teacher_logps into batch["teacher_logps"].

    Args:
        kl_coeff: Coefficient for KL loss. Higher values push the student
            harder towards the teacher's distribution. Default 1.0.
        length_normalize: If True, divide per-sample loss by number of
            action tokens. Useful when sequences have varying lengths.
        name: Algorithm name for logging/metrics.

    Example:
        ```python
        from ludic.training import RolloutBatchSource, Trainer, make_opd
        from ludic.training.scoring import make_vllm_teacher_scorer

        # Teacher scorer attached to the student Agent
        teacher = make_vllm_teacher_scorer(
            base_url="http://localhost:8001",
            model="Qwen/Qwen3-32B",
        )
        agent = Agent(..., scorers=[teacher])

        # Train with OPD
        algo = make_opd()
        batch_source = RolloutBatchSource(
            orchestrator=engine,
            credit_assigner=algo.credit_assigner,
            requests_fn=requests_fn,
        )
        trainer = Trainer(model=model, algo=algo, batch_source=batch_source, ...)
        ```

    Reference: https://thinkingmachines.ai/blog/on-policy-distillation
    """
    credit_assigner: CreditAssigner = ConstantCredit(value=1.0)
    loss: Loss = ReverseKLLoss(coeff=kl_coeff, length_normalize=length_normalize)

    return RLAlgorithm(
        name=name,
        credit_assigner=credit_assigner,
        loss=loss,
    )
+ ) + teacher_logps_batch[b, positions] = token_logps + + batch["teacher_logps"] = teacher_logps_batch + return batch @@ -112,9 +134,20 @@ def _truncate_item(item: SAWItem, max_seq_len: int) -> SAWItem: prompt_tokens = len(input_ids) - action_tokens attachments = item.attachments + new_actor_logps = None + new_teacher_logps = None if attachments.actor_logps is not None: - token_logps = attachments.actor_logps.token_logps[:action_tokens] - attachments = SampleAttachments(actor_logps=ActorTokenLogps(token_logps=token_logps)) + new_actor_logps = ActorTokenLogps( + token_logps=attachments.actor_logps.token_logps[:action_tokens] + ) + if attachments.teacher_logps is not None: + new_teacher_logps = TeacherLogprobs( + token_logps=attachments.teacher_logps.token_logps[:action_tokens] + ) + attachments = SampleAttachments( + actor_logps=new_actor_logps, + teacher_logps=new_teacher_logps, + ) meta = dict(item.meta) meta["seq_len_truncated"] = True diff --git a/src/ludic/training/batching/rollout_engine.py b/src/ludic/training/batching/rollout_engine.py index d9deee0..98ab229 100644 --- a/src/ludic/training/batching/rollout_engine.py +++ b/src/ludic/training/batching/rollout_engine.py @@ -18,6 +18,7 @@ SAWItem, SAWBatch, ActorTokenLogps, + TeacherLogprobs, SampleAttachments, RolloutRequest, ProtocolSpec, @@ -281,6 +282,10 @@ def _build_turn_saw_item( logprobs_mode: Optional[bool] = None logprobs: List[float] = [] + # Collect intrinsic scores (token-level scores are concatenated across steps) + token_scores: Dict[str, List[float]] = {} + action_scores: Dict[str, float] = {} + def _extend(ids: Sequence[int], *, is_action: bool) -> None: input_ids.extend(ids) attention_mask.extend([1] * len(ids)) @@ -332,6 +337,17 @@ def _extend(ids: Sequence[int], *, is_action: bool) -> None: logprobs_mode = True logprobs.extend(completion_logprobs) + # Collect intrinsic scores from agent step + for name, scores in step.intrinsic_scores.items(): + if isinstance(scores, list): + # Token-level 
scores: concatenate across steps + if name not in token_scores: + token_scores[name] = [] + token_scores[name].extend(scores) + else: + # Action-level scores: use last step's score (or could sum/average) + action_scores[name] = float(scores) + if require_chosen_logprobs and not logprobs_mode: raise ValueError( f"Missing completion_logprobs for rollout {rollout.id}, turn {turn.turn_id!r}, " @@ -381,11 +397,25 @@ def _extend(ids: Sequence[int], *, is_action: bool) -> None: } ) - attachments = SampleAttachments() - if logprobs_mode: - attachments = SampleAttachments( - actor_logps=ActorTokenLogps(token_logps=logprobs) - ) + # Build attachments + actor_logps = ActorTokenLogps(token_logps=logprobs) if logprobs_mode else None + teacher_logps = None + if "teacher_logps" in token_scores: + teacher_logps = TeacherLogprobs(token_logps=token_scores["teacher_logps"]) + + attachments = SampleAttachments( + actor_logps=actor_logps, + teacher_logps=teacher_logps, + ) + + # Store other intrinsic scores in meta for now + # (could extend SampleAttachments later for more score types) + if token_scores: + meta["intrinsic_token_scores"] = { + k: v for k, v in token_scores.items() if k != "teacher_logps" + } + if action_scores: + meta["intrinsic_action_scores"] = action_scores return ( SAWItem( @@ -645,6 +675,12 @@ async def generate_batch( - If sample_filter is provided, it's applied after SAWItems are created. - Filter returns True to KEEP a sample, False to DROP it. - Use ludic.training.filters for common predicates. + + Intrinsic Scoring: + - Agents can have intrinsic scorers (TokenLevelScorer, ActionLevelScorer). + - Scores are computed during Agent.act() and flow through AgentStep. + - Token-level scores (e.g., teacher_logps) are attached to SAWItem.attachments. + - Use with make_opd() algorithm for on-policy distillation training. 
"""
On-Policy Distillation (OPD) support for Ludic.

This module provides the TeacherClient protocol and implementations for computing
teacher model logprobs on student-sampled tokens.

Reference: https://thinkingmachines.ai/blog/on-policy-distillation
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, List, Protocol, TYPE_CHECKING

if TYPE_CHECKING:
    import tinker


class TeacherClient(Protocol):
    """
    Protocol for teacher models that can compute logprobs on given tokens.

    Unlike ChatClient (which generates text), TeacherClient only evaluates
    the probability of existing token sequences. This is used for on-policy
    distillation where the student samples trajectories and the teacher
    provides per-token supervision via reverse KL.

    The key distinction from ChatClient:
    - ChatClient.complete_tokens() generates new tokens
    - TeacherClient.compute_logprobs() evaluates existing tokens
    """

    async def compute_logprobs(
        self,
        token_ids: List[int],
    ) -> List[float]:
        """
        Compute per-token log probabilities for the given sequence.

        Args:
            token_ids: Full sequence [prompt + completion] as token IDs.
                The teacher evaluates P(token_i | token_0..i-1) for each i.

        Returns:
            List of logprobs, one per token position (excluding the first token
            which has no prior context). Length = len(token_ids) - 1.

            logprobs[i] = log P(token_ids[i+1] | token_ids[0..i])
        """
        ...


@dataclass
class TinkerTeacherClient:
    """
    TeacherClient backed by a Tinker SamplingClient.

    Uses Tinker's compute_logprobs_async API which efficiently computes
    logprobs in a single forward pass without generating new tokens.

    Example:
        >>> import tinker
        >>> service_client = tinker.ServiceClient()
        >>> sampling_client = service_client.create_sampling_client(
        ...     base_model="Qwen/Qwen3-32B"
        ... )
        >>> teacher = TinkerTeacherClient(sampling_client=sampling_client)
        >>> logprobs = await teacher.compute_logprobs([1, 2, 3, 4, 5])
    """

    sampling_client: Any  # tinker.SamplingClient - use Any to avoid hard dep

    async def compute_logprobs(self, token_ids: List[int]) -> List[float]:
        import tinker

        model_input = tinker.ModelInput.from_ints(token_ids)
        # compute_logprobs_async returns logprobs for all positions including
        # the first; the first token has no prior context, so drop it.
        logprobs = await self.sampling_client.compute_logprobs_async(model_input)
        return list(logprobs[1:])


@dataclass
class VLLMTeacherClient:
    """
    TeacherClient backed by a vLLM server.

    Uses the OpenAI-compatible /v1/completions endpoint with echo=True,
    logprobs enabled, and max_tokens=0 to get per-token probabilities for
    the supplied prompt without generating anything.

    Note: The prompt is passed as token IDs; the server must support the
    echo + logprobs combination on /v1/completions (vLLM does).

    Example:
        >>> teacher = VLLMTeacherClient(
        ...     base_url="http://localhost:8000",
        ...     model="Qwen/Qwen3-32B",
        ... )
        >>> logprobs = await teacher.compute_logprobs([1, 2, 3, 4, 5])
    """

    base_url: str
    model: str
    timeout: float = 60.0

    async def compute_logprobs(self, token_ids: List[int]) -> List[float]:
        import httpx

        # vLLM's /v1/completions endpoint with prompt as token IDs:
        #   echo=True    -> return logprobs for the prompt tokens themselves
        #   max_tokens=0 -> prevent any generation
        request_body = {
            "model": self.model,
            "prompt": token_ids,
            "max_tokens": 0,
            "echo": True,
            "logprobs": 1,  # Return top-1 logprobs
        }

        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(
                f"{self.base_url}/v1/completions",
                json=request_body,
            )
            response.raise_for_status()
            result = response.json()

        # Extract logprobs from response.
        # vLLM returns them in choices[0].logprobs.token_logprobs.
        # Treat both a missing field and an explicit null the same way:
        # the server did not honor echo+logprobs, which is unrecoverable.
        choice = result["choices"][0]
        logprobs_field = choice.get("logprobs") or {}
        token_logprobs = logprobs_field.get("token_logprobs")

        if token_logprobs is None:
            raise ValueError(
                "vLLM response did not include token_logprobs. "
                "Ensure the server supports logprobs with echo=True."
            )

        # The first position has no logprob (no prior context); the rest
        # align with token_ids[1:]. Some implementations return None for
        # special tokens - map those to -inf (probability zero).
        return [
            float(lp) if lp is not None else float("-inf")
            for lp in token_logprobs[1:]
        ]
" + "Ensure teacher logprobs are computed during rollout generation " + "(e.g., via RolloutEngine with teacher_client parameter)." + ) + + input_ids = batch["input_ids"] + action_mask = batch["action_mask"] + teacher_logps = batch["teacher_logps"] + + # Current policy logprobs: [B, T-1] + token_logp = compute_token_logp(logits, input_ids) + + # Shift teacher logps to align with predictions + # teacher_logps[t] is logprob of token t, but we predict token t from t-1 + teacher_logps_shifted = teacher_logps[:, 1:] # [B, T-1] + token_mask = action_mask[:, 1:].to(token_logp.dtype) # [B, T-1] + token_counts = token_mask.sum(dim=-1).clamp(min=1.0) # [B] + + # Reverse KL: log(student) - log(teacher) + # Positive when student is more confident than teacher on a token + # Negative when teacher is more confident + reverse_kl = token_logp - teacher_logps_shifted + + # Masked sum over tokens + per_sample_kl = (reverse_kl * token_mask).sum(dim=-1) # [B] + + if self.length_normalize: + per_sample_kl = per_sample_kl / token_counts + + loss = self.coeff * per_sample_kl.mean() + + # Stats - compute over masked tokens only + mask = token_mask > 0 + if mask.any(): + masked_kl = reverse_kl.masked_select(mask) + kl_mean = masked_kl.mean() + kl_std = masked_kl.std(unbiased=False) + else: + kl_mean = loss.new_zeros(()) + kl_std = loss.new_zeros(()) + + stats: Dict[str, Any] = { + "loss": loss.detach(), + "reverse_kl_mean": kl_mean.detach(), + "reverse_kl_std": kl_std.detach(), + "avg_action_tokens": token_counts.mean().detach(), + } + return loss, stats + + # --------------------------------------------------------------------------- # Composite loss # --------------------------------------------------------------------------- diff --git a/src/ludic/training/scoring.py b/src/ludic/training/scoring.py new file mode 100644 index 0000000..bbebe35 --- /dev/null +++ b/src/ludic/training/scoring.py @@ -0,0 +1,166 @@ +""" +Intrinsic scoring protocols for agents. 
+ +Intrinsic scores are agent-local evaluations of action quality, computed +during rollout generation. They are analogous to the parser (which validates +action syntax) but evaluate action quality instead. + +Two scorer types: +- TokenLevelScorer: Per-token scores (e.g., teacher logprobs for OPD) +- ActionLevelScorer: Scalar per-action scores (e.g., LLM-as-judge) + +Scores are attached to AgentActStep and flow into SAWItem for training. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Protocol, Union, runtime_checkable + + +@runtime_checkable +class TokenLevelScorer(Protocol): + """ + Per-token intrinsic scoring. + + Use cases: + - Teacher logprobs for On-Policy Distillation (OPD) + - Token-level reward models + - Per-token confidence scores + + The scorer receives completion token IDs and returns one score per token. + """ + + name: str + + async def score_tokens(self, token_ids: List[int]) -> List[float]: + """ + Compute per-token scores for the given completion. + + Args: + token_ids: Completion token IDs (not including prompt). + + Returns: + List of scores, one per token. Length must equal len(token_ids). + """ + ... + + +@runtime_checkable +class ActionLevelScorer(Protocol): + """ + Per-action intrinsic scoring (scalar). + + Use cases: + - LLM-as-a-judge + - Verifier models + - Self-consistency scores + + The scorer receives the full context and returns a single scalar. + """ + + name: str + + async def score_action(self, prompt: str, completion: str) -> float: + """ + Compute a scalar score for the action. + + Args: + prompt: The prompt text (rendered messages). + completion: The agent's completion text. + + Returns: + Scalar score for the action. + """ + ... + + +IntrinsicScorer = Union[TokenLevelScorer, ActionLevelScorer] + + +@dataclass +class VLLMTeacherScorer: + """ + TokenLevelScorer backed by vLLM server. 
+ + Computes teacher logprobs via teacher-forced prefill using + the /v1/completions endpoint with echo=True and max_tokens=0. + """ + + base_url: str + model: str + name: str = "teacher_logps" + timeout: float = 60.0 + + async def score_tokens(self, token_ids: List[int]) -> List[float]: + """ + Compute per-token logprobs from teacher model. + + Uses vLLM's echo mode to get logprobs for existing tokens. + """ + import aiohttp + + url = f"{self.base_url}/v1/completions" + payload = { + "model": self.model, + "prompt": token_ids, + "max_tokens": 0, + "echo": True, + "logprobs": 1, + } + + async with aiohttp.ClientSession() as session: + async with session.post( + url, + json=payload, + timeout=aiohttp.ClientTimeout(total=self.timeout), + ) as resp: + resp.raise_for_status() + data = await resp.json() + + logprobs_data = data["choices"][0].get("logprobs", {}) + token_logprobs = logprobs_data.get("token_logprobs", []) + + # First token has no prior context, skip it + # Return logprobs for completion tokens only + if token_logprobs and token_logprobs[0] is None: + token_logprobs = token_logprobs[1:] + + return [float(lp) if lp is not None else 0.0 for lp in token_logprobs] + + +def make_vllm_teacher_scorer( + base_url: str, + model: str, + *, + name: str = "teacher_logps", + timeout: float = 60.0, +) -> TokenLevelScorer: + """ + Create a TokenLevelScorer backed by vLLM for teacher logprobs. + + This is used for On-Policy Distillation (OPD) where a teacher model + provides per-token supervision via teacher-forced prefill. + + Args: + base_url: vLLM server URL (e.g., "http://localhost:8001"). + model: Teacher model name. + name: Attachment key for scores (default: "teacher_logps"). + timeout: Request timeout in seconds. + + Returns: + TokenLevelScorer that computes teacher logprobs. + + Example: + >>> teacher = make_vllm_teacher_scorer( + ... base_url="http://localhost:8001", + ... model="Qwen/Qwen3-32B", + ... 
) + >>> agent = Agent(client=client, ..., scorers=[teacher]) + """ + return VLLMTeacherScorer( + base_url=base_url, + model=model, + name=name, + timeout=timeout, + ) diff --git a/src/ludic/training/types.py b/src/ludic/training/types.py index 7ffcfa9..b472a3f 100644 --- a/src/ludic/training/types.py +++ b/src/ludic/training/types.py @@ -127,6 +127,27 @@ def __post_init__(self) -> None: raise TypeError("ActorTokenLogps.token_logps must be a List[float].") +@dataclass(frozen=True) +class TeacherLogprobs: + """ + Per-token logprobs from a teacher model, used for on-policy distillation. + + `token_logps[i]` corresponds to the teacher's logprob for the chosen token + at completion position i. This is computed by querying the teacher model + on the student's sampled tokens. + + Reference: https://thinkingmachines.ai/blog/on-policy-distillation + """ + + token_logps: List[float] + + def __post_init__(self) -> None: + if not isinstance(self.token_logps, list) or not all( + isinstance(v, (int, float)) for v in self.token_logps + ): + raise TypeError("TeacherLogprobs.token_logps must be a List[float].") + + @dataclass class SampleAttachments: """ @@ -137,6 +158,7 @@ class SampleAttachments: """ actor_logps: Optional[ActorTokenLogps] = None + teacher_logps: Optional[TeacherLogprobs] = None class HasActorLogps(Protocol): @@ -186,6 +208,10 @@ class SAWItem: def actor_logps(self) -> Optional[ActorTokenLogps]: return self.attachments.actor_logps + @property + def teacher_logps(self) -> Optional[TeacherLogprobs]: + return self.attachments.teacher_logps + @dataclass class SAWBatch: """ diff --git a/src/ludic/types.py b/src/ludic/types.py index 62ebfea..e6dc07b 100644 --- a/src/ludic/types.py +++ b/src/ludic/types.py @@ -184,6 +184,7 @@ class AgentStep: parent_id: Optional[str] = None tool_calls: Optional[List[Dict[str, Any]]] = None tool_results: Optional[List[Dict[str, Any]]] = None + intrinsic_scores: Dict[str, Any] = field(default_factory=dict) @dataclass class EnvironmentStep: 
From 5e43b8f98e4a616583c8bb953247e74847c41caa Mon Sep 17 00:00:00 2001 From: hallerite Date: Sun, 28 Dec 2025 22:02:15 +0100 Subject: [PATCH 02/17] use fire-and-forget for scoring --- src/ludic/agents/base_agent.py | 22 ++++---- src/ludic/interaction/multi_agent.py | 1 + src/ludic/interaction/single_agent.py | 2 + src/ludic/training/batching/rollout_engine.py | 51 +++++++++++++++++++ src/ludic/types.py | 3 +- 5 files changed, 69 insertions(+), 10 deletions(-) diff --git a/src/ludic/agents/base_agent.py b/src/ludic/agents/base_agent.py index acd722a..6254346 100644 --- a/src/ludic/agents/base_agent.py +++ b/src/ludic/agents/base_agent.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Mapping, TYPE_CHECKING, Union +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Mapping, TYPE_CHECKING, Union import torch @@ -67,6 +67,7 @@ class AgentActStep: tool_calls: Optional[List[Dict[str, Any]]] = None tool_results: Optional[List[Dict[str, Any]]] = None intrinsic_scores: Dict[str, Union[List[float], float]] = field(default_factory=dict) + pending_score_tasks: Dict[str, Awaitable[Union[List[float], float]]] = field(default_factory=dict) @dataclass @@ -259,8 +260,8 @@ async def act( # 5. Parse (format the raw text action) parse_result = self._parser(raw_action) - # 6. Run intrinsic scorers - intrinsic_scores: Dict[str, Union[List[float], float]] = {} + # 6. 
Fire off intrinsic scorers (don't await - resolve later for parallelism) + pending_score_tasks: Dict[str, Awaitable[Union[List[float], float]]] = {} if self._scorers: from ludic.training.scoring import TokenLevelScorer, ActionLevelScorer @@ -270,13 +271,16 @@ async def act( for scorer in self._scorers: if isinstance(scorer, TokenLevelScorer): - scores = await scorer.score_tokens( - list(token_trace.completion_token_ids) + # Fire and forget - create task but don't await + task = asyncio.create_task( + scorer.score_tokens(list(token_trace.completion_token_ids)) ) - intrinsic_scores[scorer.name] = scores + pending_score_tasks[scorer.name] = task elif isinstance(scorer, ActionLevelScorer): - score = await scorer.score_action(prompt_text, raw_action) - intrinsic_scores[scorer.name] = score + task = asyncio.create_task( + scorer.score_action(prompt_text, raw_action) + ) + pending_score_tasks[scorer.name] = task step = AgentActStep( prompt_messages=messages, @@ -286,7 +290,7 @@ async def act( trace=token_trace, action_target="env", loop_index=0, - intrinsic_scores=intrinsic_scores, + pending_score_tasks=pending_score_tasks, ) return AgentActResult(steps=[step]) diff --git a/src/ludic/interaction/multi_agent.py b/src/ludic/interaction/multi_agent.py index 51af49b..8d21f26 100644 --- a/src/ludic/interaction/multi_agent.py +++ b/src/ludic/interaction/multi_agent.py @@ -173,6 +173,7 @@ async def run( tool_calls=act_step.tool_calls, tool_results=act_step.tool_results, intrinsic_scores=act_step.intrinsic_scores, + pending_score_tasks=act_step.pending_score_tasks, ) collector.add(agent_id, agent_step) step_indices[agent_id] += 1 diff --git a/src/ludic/interaction/single_agent.py b/src/ludic/interaction/single_agent.py index 76b65d9..bc0985a 100644 --- a/src/ludic/interaction/single_agent.py +++ b/src/ludic/interaction/single_agent.py @@ -180,6 +180,7 @@ async def run( tool_calls=act_step.tool_calls, tool_results=act_step.tool_results, intrinsic_scores=act_step.intrinsic_scores, 
+ pending_score_tasks=act_step.pending_score_tasks, ) steps.append(agent_step) turn_agent_step_ids.append(agent_step.id) @@ -257,6 +258,7 @@ async def run( tool_calls=act_step.tool_calls, tool_results=act_step.tool_results, intrinsic_scores=act_step.intrinsic_scores, + pending_score_tasks=act_step.pending_score_tasks, ) steps.append(agent_step) turn_agent_step_ids.append(agent_step.id) diff --git a/src/ludic/training/batching/rollout_engine.py b/src/ludic/training/batching/rollout_engine.py index 98ab229..b8d3d7e 100644 --- a/src/ludic/training/batching/rollout_engine.py +++ b/src/ludic/training/batching/rollout_engine.py @@ -550,6 +550,52 @@ async def _run_one_request( return rollouts + async def _resolve_pending_scores(self, rollouts: List[Rollout]) -> None: + """ + Resolve all pending score tasks from agent steps. + + Intrinsic scorers (e.g., teacher logprobs) fire-and-forget during Agent.act(). + This method awaits all pending tasks and stores results in intrinsic_scores. + + This design allows scoring to run in parallel with: + - Other agent steps within the same rollout + - Other rollouts running concurrently + + After this method returns, all AgentStep.intrinsic_scores are populated. 
+ """ + # Collect all pending tasks with their target locations + pending: List[Tuple[AgentStep, str, asyncio.Task]] = [] + + for rollout in rollouts: + for step in rollout.steps: + if isinstance(step, AgentStep) and step.pending_score_tasks: + for name, task in step.pending_score_tasks.items(): + pending.append((step, name, task)) + + if not pending: + return + + # Await all tasks concurrently + tasks = [t for (_, _, t) in pending] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Store results back in steps + for (step, name, _), result in zip(pending, results): + if isinstance(result, Exception): + # Log but don't fail - scorer errors shouldn't break training + import logging + logging.getLogger(__name__).warning( + f"Scorer '{name}' failed for step {step.index}: {result}" + ) + else: + step.intrinsic_scores[name] = result + + # Clear pending tasks (they're resolved now) + for rollout in rollouts: + for step in rollout.steps: + if isinstance(step, AgentStep): + step.pending_score_tasks = {} + def _append_jsonl(self, rollout: Rollout) -> None: assert self.jsonl_path is not None def _serialize_step(step: Step) -> Dict[str, Any]: @@ -688,6 +734,11 @@ async def generate_batch( timeout_s=timeout_s, concurrency=concurrency, ) + + # Resolve pending score tasks from all agent steps + # This allows scoring to run in parallel with rollout generation + await self._resolve_pending_scores(rollouts) + weights = credit_assigner.compute(rollouts) items_with_lengths: List[Tuple[SAWItem, int, int]] = [] diff --git a/src/ludic/types.py b/src/ludic/types.py index e6dc07b..291540a 100644 --- a/src/ludic/types.py +++ b/src/ludic/types.py @@ -1,6 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field, asdict -from typing import Any, Dict, List, Union, Optional, Literal, Mapping +from typing import Any, Awaitable, Dict, List, Union, Optional, Literal, Mapping import logging import time import uuid @@ -185,6 +185,7 @@ class AgentStep: 
tool_calls: Optional[List[Dict[str, Any]]] = None tool_results: Optional[List[Dict[str, Any]]] = None intrinsic_scores: Dict[str, Any] = field(default_factory=dict) + pending_score_tasks: Dict[str, Awaitable[Any]] = field(default_factory=dict) @dataclass class EnvironmentStep: From 66205fb4a92064820c50752db4d4acae9a7023b6 Mon Sep 17 00:00:00 2001 From: hallerite Date: Sun, 28 Dec 2025 22:08:41 +0100 Subject: [PATCH 03/17] update --- src/ludic/training/algorithm.py | 36 +++---- src/ludic/training/distillation.py | 149 ----------------------------- 2 files changed, 19 insertions(+), 166 deletions(-) delete mode 100644 src/ludic/training/distillation.py diff --git a/src/ludic/training/algorithm.py b/src/ludic/training/algorithm.py index 9ddb823..cf5ff1b 100644 --- a/src/ludic/training/algorithm.py +++ b/src/ludic/training/algorithm.py @@ -481,10 +481,9 @@ def make_opd( Loss: ReverseKLLoss - minimizes KL(student || teacher) per token. Prerequisites: - - Rollouts must include teacher logprobs via one of: - - RolloutEngine.generate_batch(teacher_client=...) - - External post-processing that populates SAWItem.attachments.teacher_logps - - Collator must extract teacher_logps into batch["teacher_logps"] + - Agent must have a TokenLevelScorer with name="teacher_logps" + - Use make_vllm_teacher_scorer() to create the scorer + - Scorer runs during Agent.act() and scores flow through to training Args: kl_coeff: Coefficient for KL loss. 
Higher values push the student @@ -495,22 +494,25 @@ def make_opd( Example: ```python - from ludic.training import RolloutBatchSource, Trainer, make_opd - from ludic.training.distillation import TinkerTeacherClient - - # Create teacher client - teacher = TinkerTeacherClient(sampling_client=teacher_sampling_client) - - # Create batch source with teacher - batch_source = RolloutBatchSource( - engine=engine, - make_requests=make_requests_fn, - credit_assigner=make_opd().credit_assigner, - teacher_client=teacher, + from ludic.training import Trainer, make_opd + from ludic.training.scoring import make_vllm_teacher_scorer + from ludic.agent import Agent + + # Create teacher scorer + teacher_scorer = make_vllm_teacher_scorer( + base_url="http://localhost:8001", + model="Qwen/Qwen3-32B", + ) + + # Attach to agent - scores flow through automatically + agent = Agent( + client=client, + ..., + scorers=[teacher_scorer], ) # Train with OPD - trainer = Trainer(model=model, algorithm=make_opd(), ...) + trainer = Trainer(model=model, algo=make_opd(), ...) ``` Reference: https://thinkingmachines.ai/blog/on-policy-distillation diff --git a/src/ludic/training/distillation.py b/src/ludic/training/distillation.py deleted file mode 100644 index 1b471e5..0000000 --- a/src/ludic/training/distillation.py +++ /dev/null @@ -1,149 +0,0 @@ -""" -On-Policy Distillation (OPD) support for Ludic. - -This module provides the TeacherClient protocol and implementations for computing -teacher model logprobs on student-sampled tokens. - -Reference: https://thinkingmachines.ai/blog/on-policy-distillation -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, List, Protocol, TYPE_CHECKING - -if TYPE_CHECKING: - import tinker - - -class TeacherClient(Protocol): - """ - Protocol for teacher models that can compute logprobs on given tokens. 
- - Unlike ChatClient (which generates text), TeacherClient only evaluates - the probability of existing token sequences. This is used for on-policy - distillation where the student samples trajectories and the teacher - provides per-token supervision via reverse KL. - - The key distinction from ChatClient: - - ChatClient.complete_tokens() generates new tokens - - TeacherClient.compute_logprobs() evaluates existing tokens - """ - - async def compute_logprobs( - self, - token_ids: List[int], - ) -> List[float]: - """ - Compute per-token log probabilities for the given sequence. - - Args: - token_ids: Full sequence [prompt + completion] as token IDs. - The teacher evaluates P(token_i | token_0..i-1) for each i. - - Returns: - List of logprobs, one per token position (excluding the first token - which has no prior context). Length = len(token_ids) - 1. - - logprobs[i] = log P(token_ids[i+1] | token_ids[0..i]) - """ - ... - - -@dataclass -class TinkerTeacherClient: - """ - TeacherClient backed by a Tinker SamplingClient. - - Uses Tinker's compute_logprobs_async API which efficiently computes - logprobs in a single forward pass without generating new tokens. - - Example: - >>> import tinker - >>> service_client = tinker.ServiceClient() - >>> sampling_client = service_client.create_sampling_client( - ... base_model="Qwen/Qwen3-32B" - ... 
) - >>> teacher = TinkerTeacherClient(sampling_client=sampling_client) - >>> logprobs = await teacher.compute_logprobs([1, 2, 3, 4, 5]) - """ - - sampling_client: Any # tinker.SamplingClient - use Any to avoid hard dep - - async def compute_logprobs(self, token_ids: List[int]) -> List[float]: - import tinker - - model_input = tinker.ModelInput.from_ints(token_ids) - # compute_logprobs_async returns logprobs for all positions including first - # First token has no prior, so we skip it - logprobs = await self.sampling_client.compute_logprobs_async(model_input) - return list(logprobs[1:]) - - -@dataclass -class VLLMTeacherClient: - """ - TeacherClient backed by a vLLM server. - - Uses the OpenAI-compatible /v1/completions endpoint with echo=True - and logprobs enabled to get per-token probabilities without generation. - - Note: This requires the prompt to be passed as token IDs and the server - to support the prompt_logprobs parameter (vLLM extension). - - Example: - >>> teacher = VLLMTeacherClient( - ... base_url="http://localhost:8000", - ... model="Qwen/Qwen3-32B", - ... 
) - >>> logprobs = await teacher.compute_logprobs([1, 2, 3, 4, 5]) - """ - - base_url: str - model: str - timeout: float = 60.0 - - async def compute_logprobs(self, token_ids: List[int]) -> List[float]: - import httpx - - # vLLM's /v1/completions endpoint with prompt as token IDs - # echo=True returns logprobs for prompt tokens - # max_tokens=0 prevents any generation - request_body = { - "model": self.model, - "prompt": token_ids, - "max_tokens": 0, - "echo": True, - "logprobs": 1, # Return top-1 logprobs - } - - async with httpx.AsyncClient(timeout=self.timeout) as client: - response = await client.post( - f"{self.base_url}/v1/completions", - json=request_body, - ) - response.raise_for_status() - result = response.json() - - # Extract logprobs from response - # vLLM returns logprobs in choices[0].logprobs.token_logprobs - choice = result["choices"][0] - token_logprobs = choice.get("logprobs", {}).get("token_logprobs", []) - - if token_logprobs is None: - raise ValueError( - "vLLM response did not include token_logprobs. " - "Ensure the server supports logprobs with echo=True." 
- ) - - # First token has no logprob (or is None), skip it - # The rest should align with token_ids[1:] - logprobs = [] - for lp in token_logprobs[1:]: - if lp is None: - # Some implementations return None for special tokens - logprobs.append(float("-inf")) - else: - logprobs.append(float(lp)) - - return logprobs From b9ae1ef743bc94336d3959ec7790803d996a6117 Mon Sep 17 00:00:00 2001 From: hallerite Date: Sun, 28 Dec 2025 22:19:50 +0100 Subject: [PATCH 04/17] send prompt token ids too --- src/ludic/agents/base_agent.py | 5 ++++- src/ludic/training/scoring.py | 39 ++++++++++++++++++++++++---------- 2 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/ludic/agents/base_agent.py b/src/ludic/agents/base_agent.py index 6254346..53c0d44 100644 --- a/src/ludic/agents/base_agent.py +++ b/src/ludic/agents/base_agent.py @@ -273,7 +273,10 @@ async def act( if isinstance(scorer, TokenLevelScorer): # Fire and forget - create task but don't await task = asyncio.create_task( - scorer.score_tokens(list(token_trace.completion_token_ids)) + scorer.score_tokens( + list(token_trace.prompt_token_ids), + list(token_trace.completion_token_ids), + ) ) pending_score_tasks[scorer.name] = task elif isinstance(scorer, ActionLevelScorer): diff --git a/src/ludic/training/scoring.py b/src/ludic/training/scoring.py index bbebe35..1759a8a 100644 --- a/src/ludic/training/scoring.py +++ b/src/ludic/training/scoring.py @@ -28,20 +28,27 @@ class TokenLevelScorer(Protocol): - Token-level reward models - Per-token confidence scores - The scorer receives completion token IDs and returns one score per token. + The scorer receives prompt and completion token IDs, returns one score + per completion token. """ name: str - async def score_tokens(self, token_ids: List[int]) -> List[float]: + async def score_tokens( + self, + prompt_token_ids: List[int], + completion_token_ids: List[int], + ) -> List[float]: """ Compute per-token scores for the given completion. 
Args: - token_ids: Completion token IDs (not including prompt). + prompt_token_ids: Prompt token IDs (for conditioning context). + completion_token_ids: Completion token IDs to score. Returns: - List of scores, one per token. Length must equal len(token_ids). + List of scores, one per completion token. + Length must equal len(completion_token_ids). """ ... @@ -92,18 +99,28 @@ class VLLMTeacherScorer: name: str = "teacher_logps" timeout: float = 60.0 - async def score_tokens(self, token_ids: List[int]) -> List[float]: + async def score_tokens( + self, + prompt_token_ids: List[int], + completion_token_ids: List[int], + ) -> List[float]: """ Compute per-token logprobs from teacher model. Uses vLLM's echo mode to get logprobs for existing tokens. + The full sequence (prompt + completion) is sent so the teacher + can properly condition on the prompt context. """ import aiohttp + # Send full sequence so teacher conditions on prompt + full_sequence = prompt_token_ids + completion_token_ids + prompt_len = len(prompt_token_ids) + url = f"{self.base_url}/v1/completions" payload = { "model": self.model, - "prompt": token_ids, + "prompt": full_sequence, "max_tokens": 0, "echo": True, "logprobs": 1, @@ -121,12 +138,12 @@ async def score_tokens(self, token_ids: List[int]) -> List[float]: logprobs_data = data["choices"][0].get("logprobs", {}) token_logprobs = logprobs_data.get("token_logprobs", []) - # First token has no prior context, skip it - # Return logprobs for completion tokens only - if token_logprobs and token_logprobs[0] is None: - token_logprobs = token_logprobs[1:] + # Extract only completion token logprobs (skip prompt tokens) + # First token of full sequence has logprob=None, but completion tokens + # all have valid logprobs since they're conditioned on prior tokens + completion_logprobs = token_logprobs[prompt_len:] - return [float(lp) if lp is not None else 0.0 for lp in token_logprobs] + return [float(lp) if lp is not None else 0.0 for lp in completion_logprobs] 
def make_vllm_teacher_scorer( From 197f51159b1c89b04c0a5504aa332ffea3c0f999 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 00:16:52 +0100 Subject: [PATCH 05/17] add note --- src/ludic/training/scoring.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/src/ludic/training/scoring.py b/src/ludic/training/scoring.py index 1759a8a..90a5e09 100644 --- a/src/ludic/training/scoring.py +++ b/src/ludic/training/scoring.py @@ -92,6 +92,18 @@ class VLLMTeacherScorer: Computes teacher logprobs via teacher-forced prefill using the /v1/completions endpoint with echo=True and max_tokens=0. + + Assumptions: + - Teacher and student use the SAME TOKENIZER. Token IDs from + the student are sent directly to the teacher. If tokenizers + differ, logprobs will be meaningless or the request will fail. + - The full sequence (prompt + completion) fits within the teacher's + context window. If the sequence exceeds the context limit, vLLM + may truncate and return fewer logprobs than completion tokens, + causing a length mismatch error during batch collation. + + For models with different tokenizers, you would need a custom scorer + that re-tokenizes the text with the teacher's tokenizer. """ base_url: str @@ -159,9 +171,14 @@ def make_vllm_teacher_scorer( This is used for On-Policy Distillation (OPD) where a teacher model provides per-token supervision via teacher-forced prefill. + Important: The teacher model MUST use the same tokenizer as the student. + Token IDs are passed directly without re-tokenization. Also ensure the + full sequence (prompt + completion) fits within the teacher's context + window to avoid truncation errors. + Args: base_url: vLLM server URL (e.g., "http://localhost:8001"). - model: Teacher model name. + model: Teacher model name (must share tokenizer with student). name: Attachment key for scores (default: "teacher_logps"). timeout: Request timeout in seconds. 
From 3a1119eb229fd5c30ee83c73957154ba6545f5bc Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 00:35:03 +0100 Subject: [PATCH 06/17] update example; add docs --- examples/opd/README.md | 99 +++++++++++++++++++++++++++++++++ examples/opd/train_opd_gsm8k.py | 4 +- src/ludic/training/scoring.py | 2 +- 3 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 examples/opd/README.md diff --git a/examples/opd/README.md b/examples/opd/README.md new file mode 100644 index 0000000..c387536 --- /dev/null +++ b/examples/opd/README.md @@ -0,0 +1,99 @@ +# On-Policy Distillation (OPD) Training on GSM8K + +Train a smaller student model using dense per-token supervision from a larger teacher model. + +OPD combines the benefits of: +- **On-policy learning**: Student samples from itself (not teacher demonstrations) +- **Dense supervision**: Per-token feedback via reverse KL divergence (not sparse rewards) + +Reference: https://thinkingmachines.ai/blog/on-policy-distillation + +## Prerequisites + +- At least 2 GPUs (e.g., 2x A100). + - GPU 0: Both vLLM servers (student 0.5B + teacher 7B fit together) + - GPU 1: Training (gradient updates) +- Required extra packages: `datasets`, `math-verify`. + +Install deps (once): +```bash +uv sync --extra examples +``` + +## 1) Start vLLM servers + +You need **two** vLLM servers: one for the student (sampling) and one for the teacher (scoring). For these small models, both can share GPU 0. + +**Important**: Student and teacher must use the **same tokenizer**. The Qwen2.5 family shares tokenizers across sizes, so this works. 
+ +### Terminal 1: Student server (port 8000) +```bash +CUDA_VISIBLE_DEVICES=0 uv run python -m ludic.inference.vllm_server \ + --model Qwen/Qwen2.5-0.5B-Instruct \ + --port 8000 \ + --gpu-memory-utilization 0.4 +``` + +### Terminal 2: Teacher server (port 8001) +```bash +CUDA_VISIBLE_DEVICES=0 uv run python -m ludic.inference.vllm_server \ + --model Qwen/Qwen2.5-7B-Instruct \ + --port 8001 \ + --gpu-memory-utilization 0.5 +``` + +Wait for both servers to report ready before proceeding. + +## 2) Train with OPD + +In a third terminal, run the OPD training script on GPU 1: +```bash +CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. uv run python examples/opd/train_opd_gsm8k.py \ + --student-model Qwen/Qwen2.5-0.5B-Instruct \ + --teacher-model Qwen/Qwen2.5-7B-Instruct \ + --student-port 8000 \ + --teacher-port 8001 \ + --rollouts-per-update 64 \ + --train-steps 100 \ + --micro-token-budget 16384 \ + --max-seq-len 1024 +``` + +### Key flags + +| Flag | Default | Description | +|------|---------|-------------| +| `--student-model` | `Qwen/Qwen3-8B-Base` | Student model (must match vLLM server) | +| `--teacher-model` | `Qwen/Qwen3-32B` | Teacher model (must share tokenizer with student) | +| `--student-port` | 8000 | Student vLLM server port | +| `--teacher-port` | 8001 | Teacher vLLM server port | +| `--kl-coeff` | 1.0 | Coefficient for reverse KL loss | +| `--length-normalize` | False | Normalize loss by sequence length | +| `--rollouts-per-update` | 64 | Rollouts per training step | +| `--concurrency` | 32 | Parallel rollout generation | +| `--limit` | None | Limit training samples (None = use all) | + +### Training logs + +Output includes: +- `train/loss`: Reverse KL loss +- `train/reverse_kl_mean`: Mean per-token KL divergence +- `train/correct_rate`: GSM8K accuracy on training samples +- `train/avg_completion_length`: Average tokens per completion + +Rollouts are written to `opd_rollouts.jsonl`. + +## How OPD works + +1. 
**Student samples**: The student model generates completions for GSM8K problems +2. **Teacher scores**: The teacher model computes per-token logprobs on the student's samples +3. **Reverse KL loss**: Training minimizes `KL(student || teacher) = log π_student - log π_teacher` + +This pushes the student to assign high probability to tokens the teacher prefers, while staying on-policy (sampling from itself). + +## Tips + +- **Same tokenizer is required**: OPD passes token IDs directly from student to teacher. If tokenizers differ, results will be meaningless. +- **Context window**: Ensure prompt + completion fits in teacher's context window. Truncation causes length mismatches. +- **GPU memory**: With larger models, you may need separate GPUs for student and teacher. Adjust `--gpu-memory-utilization` accordingly. +- **KL coefficient**: Start with `--kl-coeff 1.0`. Increase if student diverges too much from teacher. diff --git a/examples/opd/train_opd_gsm8k.py b/examples/opd/train_opd_gsm8k.py index fa650c0..fd54da2 100644 --- a/examples/opd/train_opd_gsm8k.py +++ b/examples/opd/train_opd_gsm8k.py @@ -96,9 +96,9 @@ def main(): parser = argparse.ArgumentParser(description="OPD training on GSM8K") # Model configuration - parser.add_argument("--student-model", default="Qwen/Qwen3-8B-Base", + parser.add_argument("--student-model", default="Qwen/Qwen2.5-0.5B-Instruct", help="Student model name/path") - parser.add_argument("--teacher-model", default="Qwen/Qwen3-32B", + parser.add_argument("--teacher-model", default="Qwen/Qwen2.5-7B-Instruct", help="Teacher model name/path") # vLLM server configuration diff --git a/src/ludic/training/scoring.py b/src/ludic/training/scoring.py index 90a5e09..2e712f3 100644 --- a/src/ludic/training/scoring.py +++ b/src/ludic/training/scoring.py @@ -188,7 +188,7 @@ def make_vllm_teacher_scorer( Example: >>> teacher = make_vllm_teacher_scorer( ... base_url="http://localhost:8001", - ... model="Qwen/Qwen3-32B", + ... 
model="Qwen/Qwen2.5-7B-Instruct", ... ) >>> agent = Agent(client=client, ..., scorers=[teacher]) """ From 61d851da92047aa0ebdcf1fe43785cfd384d24e8 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 00:56:04 +0100 Subject: [PATCH 07/17] add loggers --- examples/opd/train_opd_gsm8k.py | 42 +++++++++++++++++++++++++++++++-- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/examples/opd/train_opd_gsm8k.py b/examples/opd/train_opd_gsm8k.py index fd54da2..74eccee 100644 --- a/examples/opd/train_opd_gsm8k.py +++ b/examples/opd/train_opd_gsm8k.py @@ -49,7 +49,7 @@ make_opd, RequestsExhausted, ) -from ludic.training import Reducer, PrintLogger, default_reducers +from ludic.training import Reducer, PrintLogger, RichLiveLogger, TeeLogger, WandbLogger, default_reducers from ludic.training.scoring import make_vllm_teacher_scorer # Try to import environments @@ -138,6 +138,10 @@ def main(): parser.add_argument("--system-prompt", type=str, default="First, think step by step. Then put your final answer inside \\boxed{...}.") + # Logging + parser.add_argument("--logger", type=str, default="rich", + help="Comma-separated loggers: rich, print, wandb, none.") + args = parser.parse_args() # Load training data @@ -281,7 +285,41 @@ def protocol_factory(): "train/avg_completion_length", "train/num_samples", ] - train_logger = PrintLogger(prefix="[opd]", keys=logger_keys, precision=4) + + train_logger = None + raw_logger = args.logger or "rich" + logger_tokens = [tok.strip().lower() for tok in raw_logger.replace("+", ",").split(",") if tok.strip()] + valid_loggers = {"rich", "print", "wandb", "none"} + unknown = [tok for tok in logger_tokens if tok not in valid_loggers] + if unknown: + raise SystemExit(f"Unknown logger(s): {unknown}. 
Valid: {sorted(valid_loggers)}") + if "none" in logger_tokens: + logger_tokens = ["none"] + + console_logger = None + if "print" in logger_tokens: + console_logger = PrintLogger(prefix="[opd]", keys=logger_keys, precision=4) + elif "rich" in logger_tokens: + import sys + if not sys.stdout.isatty(): + console_logger = PrintLogger(prefix="[opd]", keys=logger_keys, precision=4) + else: + console_logger = RichLiveLogger( + keys=logger_keys, + spark_key="train/reverse_kl_mean", + history=100, + precision=4, + ) + + wandb_logger = None + if "wandb" in logger_tokens: + wandb_logger = WandbLogger(config=dict(vars(args))) + + if logger_tokens != ["none"]: + if console_logger and wandb_logger: + train_logger = TeeLogger(console_logger, wandb_logger) + else: + train_logger = console_logger or wandb_logger # Create trainer trainer = Trainer( From c4270f4c5e767cbdbc7432b268e6435149670366 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 01:04:43 +0100 Subject: [PATCH 08/17] add eval --- examples/opd/README.md | 5 ++- examples/opd/train_opd_gsm8k.py | 66 +++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/examples/opd/README.md b/examples/opd/README.md index c387536..1f97910 100644 --- a/examples/opd/README.md +++ b/examples/opd/README.md @@ -63,8 +63,8 @@ CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. 
uv run python examples/opd/train_opd_gsm8k.p | Flag | Default | Description | |------|---------|-------------| -| `--student-model` | `Qwen/Qwen3-8B-Base` | Student model (must match vLLM server) | -| `--teacher-model` | `Qwen/Qwen3-32B` | Teacher model (must share tokenizer with student) | +| `--student-model` | `Qwen/Qwen2.5-0.5B-Instruct` | Student model (must match vLLM server) | +| `--teacher-model` | `Qwen/Qwen2.5-7B-Instruct` | Teacher model (must share tokenizer with student) | | `--student-port` | 8000 | Student vLLM server port | | `--teacher-port` | 8001 | Teacher vLLM server port | | `--kl-coeff` | 1.0 | Coefficient for reverse KL loss | @@ -72,6 +72,7 @@ CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. uv run python examples/opd/train_opd_gsm8k.p | `--rollouts-per-update` | 64 | Rollouts per training step | | `--concurrency` | 32 | Parallel rollout generation | | `--limit` | None | Limit training samples (None = use all) | +| `--logger` | `rich` | Loggers: rich, print, wandb, none (comma-separated) | ### Training logs diff --git a/examples/opd/train_opd_gsm8k.py b/examples/opd/train_opd_gsm8k.py index 74eccee..c0fb902 100644 --- a/examples/opd/train_opd_gsm8k.py +++ b/examples/opd/train_opd_gsm8k.py @@ -48,8 +48,12 @@ make_dataset_queue_requests_fn, make_opd, RequestsExhausted, + RolloutRequest, + EnvSpec, + ProtocolSpec, ) from ludic.training import Reducer, PrintLogger, RichLiveLogger, TeeLogger, WandbLogger, default_reducers +from ludic.eval import EngineEvaluator from ludic.training.scoring import make_vllm_teacher_scorer # Try to import environments @@ -142,6 +146,17 @@ def main(): parser.add_argument("--logger", type=str, default="rich", help="Comma-separated loggers: rich, print, wandb, none.") + # Evaluation + parser.add_argument("--eval-every", type=int, default=10, + help="Eval every N train steps.") + parser.add_argument("--eval-before-start", action="store_true", default=True, + help="Run eval once before training begins.") + 
parser.add_argument("--eval-limit", type=int, default=1000, + help="Number of test samples for eval.") + parser.add_argument("--eval-concurrency", type=int, default=64) + parser.add_argument("--eval-temperature", type=float, default=0.0, + help="Sampling temperature for eval passes.") + args = parser.parse_args() # Load training data @@ -151,6 +166,11 @@ def main(): raise SystemExit("No GSM8K samples loaded.") print(f"Loaded {len(train_samples)} training samples") + # Load eval data + eval_samples = load_gsm8k("test", args.eval_limit) if args.eval_limit else [] + if eval_samples: + print(f"Loaded {len(eval_samples)} eval samples") + # Create sample queue samples_q: queue.Queue = queue.Queue() for idx, s in enumerate(train_samples): @@ -264,6 +284,10 @@ def protocol_factory(): micro_token_budget=args.micro_token_budget, max_grad_norm=0.5, pad_token_id=tokenizer, + eval_at_start=bool(args.eval_before_start and eval_samples), + eval_every_n_steps=(args.eval_every if args.eval_every and args.eval_every > 0 and eval_samples else None), + eval_concurrency=args.eval_concurrency, + eval_max_steps=1, ) # Reducers for logging @@ -276,6 +300,13 @@ def protocol_factory(): ), } + # Eval reducers + eval_reducers = { + "accuracy": Reducer(kind="count_true", source="correct", normalize_by="samples", as_percent=True), + "parse_error_rate": Reducer(kind="count_true", source="parse_error", normalize_by="samples", as_percent=True), + "avg_completion_tokens": Reducer(kind="mean", source="completion_length"), + } + # Logger logger_keys = [ "train/loss", @@ -284,6 +315,9 @@ def protocol_factory(): "train/correct_rate", "train/avg_completion_length", "train/num_samples", + "eval/accuracy", + "eval/parse_error_rate", + "eval/avg_completion_tokens", ] train_logger = None @@ -331,6 +365,38 @@ def protocol_factory(): cfg=cfg, train_logger=train_logger, reducers=reducers, + evaluator=( + None + if not eval_samples + else EngineEvaluator( + engine=RolloutEngine(env_registry=env_registry, 
protocol_registry=protocol_registry), + requests_fn=lambda: [ + RolloutRequest( + env=EnvSpec( + kind="gsm8k", + kwargs={"sample": sample}, + ), + protocol=ProtocolSpec(kind="single_agent"), + env_seed=idx, + sampling_seed=idx, + inference=InferenceSpec( + sampling=SamplingParams( + temperature=args.eval_temperature, + max_tokens=args.max_completion_tokens, + ), + return_=ReturnSpec.for_eval(return_token_ids=True), + ), + num_episodes=1, + meta={"eval_sample_index": idx, "question_id": sample.get("id", idx)}, + ) + for idx, sample in enumerate(eval_samples) + ], + reducers=eval_reducers, + max_steps=1, + timeout_s=cfg.eval_timeout_s, + concurrency=cfg.eval_concurrency, + ) + ), ) # Train From d96657acd1e48215e5528fc0ebcb7d5095ffecbe Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 01:25:57 +0100 Subject: [PATCH 09/17] update --- examples/opd/README.md | 5 + examples/opd/train_opd_gsm8k.py | 163 +++++++++++++++++--------------- 2 files changed, 91 insertions(+), 77 deletions(-) diff --git a/examples/opd/README.md b/examples/opd/README.md index 1f97910..864a84f 100644 --- a/examples/opd/README.md +++ b/examples/opd/README.md @@ -73,6 +73,9 @@ CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. 
uv run python examples/opd/train_opd_gsm8k.p | `--concurrency` | 32 | Parallel rollout generation | | `--limit` | None | Limit training samples (None = use all) | | `--logger` | `rich` | Loggers: rich, print, wandb, none (comma-separated) | +| `--eval-every` | 10 | Eval every N train steps | +| `--eval-limit` | 1000 | Number of test samples for eval | +| `--eval-temperature` | 0.0 | Sampling temperature for eval (greedy) | ### Training logs @@ -81,6 +84,8 @@ Output includes: - `train/reverse_kl_mean`: Mean per-token KL divergence - `train/correct_rate`: GSM8K accuracy on training samples - `train/avg_completion_length`: Average tokens per completion +- `eval/accuracy`: GSM8K accuracy on test set +- `eval/parse_error_rate`: Parse error rate on test set Rollouts are written to `opd_rollouts.jsonl`. diff --git a/examples/opd/train_opd_gsm8k.py b/examples/opd/train_opd_gsm8k.py index c0fb902..978fcdb 100644 --- a/examples/opd/train_opd_gsm8k.py +++ b/examples/opd/train_opd_gsm8k.py @@ -17,8 +17,8 @@ Usage: python train_opd_gsm8k.py \ - --student-model Qwen/Qwen3-8B-Base \ - --teacher-model Qwen/Qwen3-32B \ + --student-model Qwen/Qwen2.5-0.5B-Instruct \ + --teacher-model Qwen/Qwen2.5-7B-Instruct \ --limit 1000 Requirements: @@ -28,6 +28,8 @@ from __future__ import annotations import argparse +import os +import sys import queue from typing import List, Dict, Any @@ -37,14 +39,17 @@ from ludic.agent import Agent from ludic.context import FullDialog -from ludic.inference import InferenceSpec, SamplingParams, ReturnSpec, HFChatTemplate +from ludic.inference import VLLMChatClient, InferenceSpec, SamplingParams, ReturnSpec, HFChatTemplate from ludic.interaction import SingleAgentSyncProtocol from ludic.parsers import boxed_parser +from ludic.distributed.adapters import create_vllm_publisher +from ludic.eval import EngineEvaluator from ludic.training import ( RolloutEngine, RolloutBatchSource, Trainer, TrainerConfig, + CheckpointConfig, make_dataset_queue_requests_fn, 
make_opd, RequestsExhausted, @@ -52,31 +57,9 @@ EnvSpec, ProtocolSpec, ) -from ludic.training import Reducer, PrintLogger, RichLiveLogger, TeeLogger, WandbLogger, default_reducers -from ludic.eval import EngineEvaluator +from ludic.training import Reducer, RichLiveLogger, PrintLogger, TeeLogger, WandbLogger, default_reducers from ludic.training.scoring import make_vllm_teacher_scorer - -# Try to import environments -try: - from environments.gsm8k import GSM8KEnv -except ImportError: - # Fallback: define a minimal GSM8K env - from ludic.envs import DatasetQAEnv - - class GSM8KEnv(DatasetQAEnv): - def __init__(self, sample: Dict[str, Any], system_prompt: str = ""): - super().__init__( - question=sample["question"], - ground_truth=self._extract_answer(sample["answer"]), - system_prompt=system_prompt or "Solve the following problem step by step. Put your final answer in \\boxed{}.", - ) - - @staticmethod - def _extract_answer(answer_text: str) -> str: - # GSM8K answers have format "...\n#### answer" - if "####" in answer_text: - return answer_text.split("####")[-1].strip() - return answer_text.strip() +from environments.gsm8k import GSM8KEnv def load_gsm8k(split: str, limit: int | None) -> List[Dict[str, Any]]: @@ -120,16 +103,16 @@ def main(): parser.add_argument("--rollouts-per-update", type=int, default=64, help="Number of rollouts per training step") parser.add_argument("--train-steps", type=int, default=100, - help="Number of training steps") - parser.add_argument("--max-seq-len", type=int, default=2048, - help="Max sequence length") - parser.add_argument("--micro-token-budget", type=int, default=32768, + help="Number of training steps; 0 = run until samples exhausted") + parser.add_argument("--max-seq-len", type=int, default=1024, + help="Max tokens per sample") + parser.add_argument("--micro-token-budget", type=int, default=16384, help="Max padded tokens per micro-batch") - parser.add_argument("--max-completion-tokens", type=int, default=1024, + 
parser.add_argument("--max-completion-tokens", type=int, default=512, help="Max completion tokens per rollout") parser.add_argument("--temperature", type=float, default=1.0, - help="Sampling temperature") - parser.add_argument("--concurrency", type=int, default=32, + help="Sampling temperature for training rollouts") + parser.add_argument("--concurrency", type=int, default=64, help="Rollout concurrency") # OPD-specific configuration @@ -140,9 +123,11 @@ def main(): # System prompt parser.add_argument("--system-prompt", type=str, - default="First, think step by step. Then put your final answer inside \\boxed{...}.") + default="First, think step by step. Then put your final answer inside \\boxed{...}.", + help="System prompt for GSM8K env; set to '' to use the model default.") # Logging + parser.add_argument("--rollout-log", type=str, default="opd_rollouts.jsonl") parser.add_argument("--logger", type=str, default="rich", help="Comma-separated loggers: rich, print, wandb, none.") @@ -157,8 +142,24 @@ def main(): parser.add_argument("--eval-temperature", type=float, default=0.0, help="Sampling temperature for eval passes.") + # Checkpointing + parser.add_argument("--final-save", action="store_true", + help="Save a final checkpoint after training completes.") + args = parser.parse_args() + # Validation + if args.rollouts_per_update <= 0: + raise ValueError("--rollouts-per-update must be > 0.") + if args.max_completion_tokens > args.max_seq_len: + raise ValueError("--max-completion-tokens must be <= --max-seq-len.") + + # Setup rollout log path + rollout_log_path = os.path.abspath(args.rollout_log) + os.makedirs(os.path.dirname(rollout_log_path) or ".", exist_ok=True) + # Touch the file so tailing works even before the first rollout is written + open(rollout_log_path, "a", encoding="utf-8").close() + # Load training data print(f"Loading GSM8K {args.split} split...") train_samples = load_gsm8k(args.split, args.limit) @@ -179,24 +180,17 @@ def main(): # Load tokenizer and 
model print(f"Loading student model: {args.student_model}") tokenizer = AutoTokenizer.from_pretrained(args.student_model) - model = AutoModelForCausalLM.from_pretrained( - args.student_model, - torch_dtype=torch.bfloat16, - ) - device = "cuda" if torch.cuda.is_available() else "cpu" - model.to(device) - - # Create vLLM clients - from ludic.inference import VLLMChatClient - from ludic.distributed.adapters import create_vllm_publisher + model = AutoModelForCausalLM.from_pretrained(args.student_model, dtype=torch.bfloat16) + model.to("cuda" if torch.cuda.is_available() else "cpu") - # Student client for sampling + # Create vLLM client for student client = VLLMChatClient( host=args.student_host, port=args.student_port, enable_weight_updates=True, ) publisher = create_vllm_publisher(client) + chat_template = HFChatTemplate(tokenizer) # Teacher scorer - computes per-token logprobs during Agent.act() teacher_scorer = make_vllm_teacher_scorer( @@ -204,17 +198,11 @@ def main(): model=args.teacher_model, ) - chat_template = HFChatTemplate(tokenizer) - - # Environment and protocol registries + # Registries env_registry = { - "gsm8k": lambda sample: GSM8KEnv( - sample=sample, - system_prompt=args.system_prompt, - ) + "gsm8k": lambda sample: GSM8KEnv(sample=sample, system_prompt=args.system_prompt) } - # Agent with teacher scorer - scores flow through to training def protocol_factory(): return SingleAgentSyncProtocol( agent=Agent( @@ -229,21 +217,19 @@ def protocol_factory(): protocol_registry = {"single_agent": protocol_factory} - # Create OPD algorithm + # Algorithm (OPD: reverse KL loss with constant credit) algo = make_opd( kl_coeff=args.kl_coeff, length_normalize=args.length_normalize, name="opd", ) - # Create rollout engine + # Engine + batch source engine = RolloutEngine( env_registry=env_registry, protocol_registry=protocol_registry, - jsonl_path="opd_rollouts.jsonl", + jsonl_path=rollout_log_path, ) - - # Create inference spec train_inference = InferenceSpec( 
sampling=SamplingParams( temperature=args.temperature, @@ -251,8 +237,6 @@ def protocol_factory(): ), return_=ReturnSpec.for_rl(top_logprobs_k=1), ) - - # Create requests function requests_fn = make_dataset_queue_requests_fn( samples_q, batch_size=args.rollouts_per_update, @@ -267,8 +251,6 @@ def protocol_factory(): env_seed_fn=lambda idx, _sample: idx, sampling_seed_fn=lambda idx, _sample: idx, ) - - # Create batch source (no teacher_client needed - it's in the Agent!) batch_source = RolloutBatchSource( orchestrator=engine, credit_assigner=algo.credit_assigner, @@ -279,7 +261,7 @@ def protocol_factory(): # Trainer config cfg = TrainerConfig( - model_device=device, + model_device="cuda" if torch.cuda.is_available() else "cpu", max_seq_len=args.max_seq_len, micro_token_budget=args.micro_token_budget, max_grad_norm=0.5, @@ -290,34 +272,47 @@ def protocol_factory(): eval_max_steps=1, ) - # Reducers for logging + # Checkpoint config + checkpoint_cfg = CheckpointConfig( + output_dir="checkpoints_opd", + every_n_steps=25, + max_to_keep=2, + save_optimizer=True, + ) + + # Reducers reducers = { - **default_reducers(), "correct_rate": Reducer( kind="count_true", source="correct", normalize_by="rollouts", ), + "parse_err_rate": Reducer( + kind="count_true", + source="parse_error", + normalize_by="samples", + ), + "total_completion_tokens": Reducer( + kind="sum", + source="completion_length", + ), } + reducers = {**default_reducers(), **reducers} - # Eval reducers - eval_reducers = { - "accuracy": Reducer(kind="count_true", source="correct", normalize_by="samples", as_percent=True), - "parse_error_rate": Reducer(kind="count_true", source="parse_error", normalize_by="samples", as_percent=True), - "avg_completion_tokens": Reducer(kind="mean", source="completion_length"), - } - - # Logger + # Logger keys logger_keys = [ "train/loss", "train/reverse_kl_mean", "train/avg_total_reward", "train/correct_rate", + "train/parse_err_rate", "train/avg_completion_length", - 
"train/num_samples", + "train/total_completion_tokens", "eval/accuracy", "eval/parse_error_rate", "eval/avg_completion_tokens", + "train/target_rollouts", + "train/num_samples", ] train_logger = None @@ -334,13 +329,12 @@ def protocol_factory(): if "print" in logger_tokens: console_logger = PrintLogger(prefix="[opd]", keys=logger_keys, precision=4) elif "rich" in logger_tokens: - import sys if not sys.stdout.isatty(): console_logger = PrintLogger(prefix="[opd]", keys=logger_keys, precision=4) else: console_logger = RichLiveLogger( keys=logger_keys, - spark_key="train/reverse_kl_mean", + spark_key="train/avg_total_reward", history=100, precision=4, ) @@ -355,6 +349,13 @@ def protocol_factory(): else: train_logger = console_logger or wandb_logger + # Eval reducers + eval_reducers = { + "accuracy": Reducer(kind="count_true", source="correct", normalize_by="samples", as_percent=True), + "parse_error_rate": Reducer(kind="count_true", source="parse_error", normalize_by="samples", as_percent=True), + "avg_completion_tokens": Reducer(kind="mean", source="completion_length"), + } + # Create trainer trainer = Trainer( model=model, @@ -363,6 +364,7 @@ def protocol_factory(): publisher=publisher, enable_gradient_checkpointing=True, cfg=cfg, + checkpoint_config=checkpoint_cfg, train_logger=train_logger, reducers=reducers, evaluator=( @@ -411,6 +413,13 @@ def protocol_factory(): except RequestsExhausted: print("No more training samples; stopping.") + if args.final_save: + try: + ckpt_path = trainer.save_checkpoint(metadata={"final": True}) + print(f"Final checkpoint saved to: {ckpt_path}") + except RuntimeError: + pass # No checkpointer configured + print("\nTraining complete!") From a1d0bc6c5fb6a5d3ac4e94e99c58812ebcf97306 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 01:58:42 +0100 Subject: [PATCH 10/17] update --- examples/opd/README.md | 37 +++++++++----- examples/opd/train_opd_gsm8k.py | 90 ++++++++++++++++++++++++--------- src/ludic/training/algorithm.py 
| 6 +-- 3 files changed, 91 insertions(+), 42 deletions(-) diff --git a/examples/opd/README.md b/examples/opd/README.md index 864a84f..4d5cfe6 100644 --- a/examples/opd/README.md +++ b/examples/opd/README.md @@ -1,10 +1,14 @@ -# On-Policy Distillation (OPD) Training on GSM8K +# GSPO + OPD Hybrid Training on GSM8K -Train a smaller student model using dense per-token supervision from a larger teacher model. +Train a smaller student model using both task rewards and dense per-token supervision from a larger teacher model. -OPD combines the benefits of: -- **On-policy learning**: Student samples from itself (not teacher demonstrations) -- **Dense supervision**: Per-token feedback via reverse KL divergence (not sparse rewards) +This hybrid approach combines: +- **GSPO (Group-Sorted Policy Optimization)**: Task rewards from GSM8K correctness with group-normalized advantages +- **OPD (On-Policy Distillation)**: Dense per-token feedback via reverse KL divergence from teacher + +The composite loss balances: +1. **Task-specific learning**: Sparse but grounded rewards from environment +2. **Distribution matching**: Dense per-token guidance from teacher Reference: https://thinkingmachines.ai/blog/on-policy-distillation @@ -67,9 +71,9 @@ CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. 
uv run python examples/opd/train_opd_gsm8k.p | `--teacher-model` | `Qwen/Qwen2.5-7B-Instruct` | Teacher model (must share tokenizer with student) | | `--student-port` | 8000 | Student vLLM server port | | `--teacher-port` | 8001 | Teacher vLLM server port | -| `--kl-coeff` | 1.0 | Coefficient for reverse KL loss | -| `--length-normalize` | False | Normalize loss by sequence length | -| `--rollouts-per-update` | 64 | Rollouts per training step | +| `--kl-coeff` | 1.0 | Coefficient for reverse KL loss term | +| `--rollouts-per-update` | 256 | Total rollouts per training step | +| `--group-size` | 8 | Group size for GSPO advantages | | `--concurrency` | 32 | Parallel rollout generation | | `--limit` | None | Limit training samples (None = use all) | | `--logger` | `rich` | Loggers: rich, print, wandb, none (comma-separated) | @@ -80,8 +84,10 @@ CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. uv run python examples/opd/train_opd_gsm8k.p ### Training logs Output includes: -- `train/loss`: Reverse KL loss -- `train/reverse_kl_mean`: Mean per-token KL divergence +- `train/loss`: Combined loss (GSPO + KL) +- `train/gspo/loss`: GSPO policy gradient loss +- `train/kl/loss`: Reverse KL loss +- `train/kl/reverse_kl_mean`: Mean per-token KL divergence - `train/correct_rate`: GSM8K accuracy on training samples - `train/avg_completion_length`: Average tokens per completion - `eval/accuracy`: GSM8K accuracy on test set @@ -89,13 +95,16 @@ Output includes: Rollouts are written to `opd_rollouts.jsonl`. -## How OPD works +## How GSPO + OPD works 1. **Student samples**: The student model generates completions for GSM8K problems -2. **Teacher scores**: The teacher model computes per-token logprobs on the student's samples -3. **Reverse KL loss**: Training minimizes `KL(student || teacher) = log π_student - log π_teacher` +2. **Environment rewards**: Each completion is graded for correctness (sparse reward) +3. 
**Teacher scores**: The teacher model computes per-token logprobs on the student's samples +4. **Composite loss**: Training uses two objectives: + - **GSPO**: Policy gradient with group-normalized advantages from task rewards + - **Reverse KL**: Minimizes `KL(student || teacher) = log π_student - log π_teacher` -This pushes the student to assign high probability to tokens the teacher prefers, while staying on-policy (sampling from itself). +This gives the student task-specific learning from environment feedback while also pushing it to match the teacher's token distribution. ## Tips diff --git a/examples/opd/train_opd_gsm8k.py b/examples/opd/train_opd_gsm8k.py index 978fcdb..59fb589 100644 --- a/examples/opd/train_opd_gsm8k.py +++ b/examples/opd/train_opd_gsm8k.py @@ -1,14 +1,17 @@ """ -On-Policy Distillation (OPD) training on GSM8K using vLLM. +GSPO + OPD hybrid training on GSM8K using vLLM. -This example demonstrates on-policy distillation where: - - A student model samples trajectories - - A teacher model provides per-token logprobs as dense supervision - - Training minimizes reverse KL divergence: KL(student || teacher) +This example combines: + - GSPO (Group-Sorted Policy Optimization): Task rewards from GSM8K correctness + - OPD (On-Policy Distillation): Dense per-token supervision from teacher -This combines the benefits of: - - On-policy learning (student samples from itself) - - Dense supervision (per-token feedback, not sparse rewards) +The hybrid approach uses a composite loss: + 1. ClippedSurrogateLoss (GSPO): Policy gradient with group-normalized advantages + 2. ReverseKLLoss (OPD): KL(student || teacher) = log π_student - log π_teacher + +This gives you both: + - Task-specific learning from environment rewards (sparse but grounded) + - Distribution matching from teacher (dense per-token feedback) The key insight: teacher logprobs are an *intrinsic scorer* attached to the Agent. The scorer runs during Agent.act() and scores flow through to training. 
@@ -51,14 +54,17 @@ TrainerConfig, CheckpointConfig, make_dataset_queue_requests_fn, - make_opd, RequestsExhausted, RolloutRequest, EnvSpec, ProtocolSpec, + RLAlgorithm, ) from ludic.training import Reducer, RichLiveLogger, PrintLogger, TeeLogger, WandbLogger, default_reducers from ludic.training.scoring import make_vllm_teacher_scorer +from ludic.training.credit_assignment import GroupNormalizedReturn +from ludic.training.loss import ClippedSurrogateLoss, ReverseKLLoss, CompositeLoss, LossTerm +from ludic.training.algorithm import validate_actor_logps from environments.gsm8k import GSM8KEnv @@ -100,9 +106,11 @@ def main(): help="Limit training samples (None = use all)") # Training configuration - parser.add_argument("--rollouts-per-update", type=int, default=64, - help="Number of rollouts per training step") - parser.add_argument("--train-steps", type=int, default=100, + parser.add_argument("--rollouts-per-update", type=int, default=256, + help="Total rollouts per update (must be divisible by --group-size)") + parser.add_argument("--group-size", type=int, default=8, + help="Group size for grouped advantages") + parser.add_argument("--train-steps", type=int, default=20, help="Number of training steps; 0 = run until samples exhausted") parser.add_argument("--max-seq-len", type=int, default=1024, help="Max tokens per sample") @@ -110,16 +118,14 @@ def main(): help="Max padded tokens per micro-batch") parser.add_argument("--max-completion-tokens", type=int, default=512, help="Max completion tokens per rollout") - parser.add_argument("--temperature", type=float, default=1.0, + parser.add_argument("--train-temperature", type=float, default=1.0, help="Sampling temperature for training rollouts") parser.add_argument("--concurrency", type=int, default=64, help="Rollout concurrency") - # OPD-specific configuration + # OPD-specific configuration (hybrid GSPO + KL) parser.add_argument("--kl-coeff", type=float, default=1.0, - help="Coefficient for reverse KL loss") - 
parser.add_argument("--length-normalize", action="store_true", - help="Normalize loss by sequence length") + help="Coefficient for reverse KL loss term") # System prompt parser.add_argument("--system-prompt", type=str, @@ -151,6 +157,8 @@ def main(): # Validation if args.rollouts_per_update <= 0: raise ValueError("--rollouts-per-update must be > 0.") + if args.rollouts_per_update % args.group_size != 0: + raise ValueError("--rollouts-per-update must be divisible by --group-size.") if args.max_completion_tokens > args.max_seq_len: raise ValueError("--max-completion-tokens must be <= --max-seq-len.") @@ -217,11 +225,38 @@ def protocol_factory(): protocol_registry = {"single_agent": protocol_factory} - # Algorithm (OPD: reverse KL loss with constant credit) - algo = make_opd( - kl_coeff=args.kl_coeff, - length_normalize=args.length_normalize, - name="opd", + # Algorithm: GSPO (task rewards) + OPD (teacher KL) hybrid + # Credit assignment from GSPO: group-normalized returns + credit_assigner = GroupNormalizedReturn( + group_size=args.group_size, + normalize_adv=True, + positive_only=False, + ) + # Composite loss: PPO-style clipped surrogate + reverse KL + loss = CompositeLoss(terms=[ + LossTerm( + name="gspo", + loss=ClippedSurrogateLoss( + clip_eps_low=3e-4, + clip_eps_high=4e-4, + length_normalize=True, + ), + weight=1.0, + ), + LossTerm( + name="kl", + loss=ReverseKLLoss( + coeff=1.0, + length_normalize=True, + ), + weight=args.kl_coeff, + ), + ]) + algo = RLAlgorithm( + name="gspo_opd", + credit_assigner=credit_assigner, + loss=loss, + preprocess=validate_actor_logps, ) # Engine + batch source @@ -232,14 +267,16 @@ def protocol_factory(): ) train_inference = InferenceSpec( sampling=SamplingParams( - temperature=args.temperature, + temperature=args.train_temperature, max_tokens=args.max_completion_tokens, ), + # Ask vLLM for token IDs + chosen-token logprobs for importance sampling return_=ReturnSpec.for_rl(top_logprobs_k=1), ) + base_requests = 
args.rollouts_per_update // args.group_size requests_fn = make_dataset_queue_requests_fn( samples_q, - batch_size=args.rollouts_per_update, + batch_size=base_requests, env_kind="gsm8k", protocol_kind="single_agent", inference=train_inference, @@ -250,6 +287,7 @@ def protocol_factory(): }, env_seed_fn=lambda idx, _sample: idx, sampling_seed_fn=lambda idx, _sample: idx, + group_size=args.group_size, ) batch_source = RolloutBatchSource( orchestrator=engine, @@ -302,7 +340,9 @@ def protocol_factory(): # Logger keys logger_keys = [ "train/loss", - "train/reverse_kl_mean", + "train/gspo/loss", + "train/kl/loss", + "train/kl/reverse_kl_mean", "train/avg_total_reward", "train/correct_rate", "train/parse_err_rate", diff --git a/src/ludic/training/algorithm.py b/src/ludic/training/algorithm.py index cf5ff1b..f82f040 100644 --- a/src/ludic/training/algorithm.py +++ b/src/ludic/training/algorithm.py @@ -461,7 +461,7 @@ def make_sft( def make_opd( *, kl_coeff: float = 1.0, - length_normalize: bool = False, + length_normalize: bool = True, name: str = "opd", ) -> RLAlgorithm: """ @@ -488,8 +488,8 @@ def make_opd( Args: kl_coeff: Coefficient for KL loss. Higher values push the student harder towards the teacher's distribution. Default 1.0. - length_normalize: If True, divide per-sample loss by number of - action tokens. Useful when sequences have varying lengths. + length_normalize: If True (default), divide per-sample loss by number + of action tokens. This keeps gradients stable across sequence lengths. name: Algorithm name for logging/metrics. 
Example: From 5fa6224fae9ae2fed0192c24994e88a3abde4ca7 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 12:02:05 +0100 Subject: [PATCH 11/17] try other variant --- docs/composition.md | 163 +++++++++++++++++++++++ examples/opd/README.md | 18 ++- examples/opd/train_opd_gsm8k.py | 63 +++------ src/ludic/training/__init__.py | 7 + src/ludic/training/algorithm.py | 169 ++++++++++++++++++++++-- src/ludic/training/credit_assignment.py | 139 ++++++++++++++++++- tests/test_credit_assignment.py | 161 ++++++++++++++++++++++ 7 files changed, 657 insertions(+), 63 deletions(-) create mode 100644 docs/composition.md diff --git a/docs/composition.md b/docs/composition.md new file mode 100644 index 0000000..4bd6ea0 --- /dev/null +++ b/docs/composition.md @@ -0,0 +1,163 @@ +# Signal Composition in Ludic + +This document describes how different training signals (rewards, advantages, losses) can be composed in Ludic's RL training pipeline. + +## The Training Pipeline + +``` +Environment ──► Rewards ──► CreditAssigner ──► Advantages ──► Loss + │ │ + ▼ ▼ + Scorers CreditModifiers + (Level 1) (Level 2) +``` + +There are **three composition levels**: + +| Level | Name | Where | Implementation | +|-------|------|-------|----------------| +| **1** | Reward | Before credit assignment | Agent scorers | +| **2** | Advantage | After credit assignment, before loss | CreditModifier | +| **3** | Loss | Separate loss terms | CompositeLoss | + +## Level 1: Reward Composition + +Add signals to rewards via Agent scorers, before credit assignment. 
+ +```python +# Scorers attached to Agent add to per-step rewards +agent = Agent( + client=client, + scorers=[intrinsic_reward_scorer], # adds to step rewards +) +``` + +**Properties:** +- All signals go through the same credit assignment +- Signals interact (e.g., group normalization affects combined rewards) +- Tightest coupling between signals + +**Use when:** +- Intrinsic rewards should be treated identically to environment rewards +- You want signals to interact during advantage estimation + +## Level 2: Advantage Modification + +Modify advantages after credit assignment, before loss. + +```python +# KL penalty added to advantages +kl_penalty = -kl_coeff * (actor_logps - teacher_logps) +advantages = task_advantages + kl_penalty +# Then normal policy gradient with combined advantages +``` + +**Properties:** +- Each signal can have its own credit assignment strategy +- All signals go through the same importance ratio +- All signals go through the same loss function + +**Use when:** +- Different signals need different credit assignment (e.g., sparse task rewards vs dense KL) +- You want all signals to go through importance sampling together + +**Implementation in Ludic:** + +Use `CreditModifier` to add per-token signals to advantages: + +```python +algo = RLAlgorithm( + credit_assigner=GroupNormalizedReturn(group_size=8), + credit_modifiers=[KLCreditModifier(coeff=1.0)], + loss=ClippedSurrogateLoss(...), +) +``` + +Or use the preset: + +```python +algo = make_gspo_opd(group_size=8, kl_coeff=1.0) +``` + +## Level 3: Loss Composition + +Combine independent loss terms additively. 
+ +```python +loss = CompositeLoss(terms=[ + LossTerm(name="rl", loss=ClippedSurrogateLoss(...), weight=1.0), + LossTerm(name="auxiliary", loss=SomeAuxiliaryLoss(...), weight=0.1), +]) +``` + +**Properties:** +- Each loss computed independently +- Different losses can use different data (current vs old policy logprobs) +- Loosest coupling +- Most flexible but signals don't interact + +**Use when:** +- Truly independent objectives (e.g., RL + language modeling auxiliary) +- Different losses need fundamentally different handling +- You need maximum flexibility + +## Key Differences: Advantage vs Loss Composition + +| Aspect | Advantage Modification (Level 2) | Loss Composition (Level 3) | +|--------|----------------------------------|---------------------------| +| **KL source** | Old policy (rollout time) | Current policy (forward pass) | +| **Importance sampling** | Goes through ratio | Doesn't go through ratio | +| **Gradient** | `ratio * (task_adv + kl_penalty)` | `ratio * task_adv + kl_grad` | +| **Interaction** | Signals combined | Signals independent | + +### Mathematical Difference + +**Advantage Modification:** +``` +A_t = task_advantage - kl_coeff * KL_old_t +L = E[ ratio_t * A_t ] +∇L ∝ ratio_t * A_t * ∇log π_t +``` + +**Loss Composition:** +``` +L = E[ ratio_t * task_advantage ] + kl_coeff * E[ KL_current_t ] +∇L ∝ ratio_t * task_advantage * ∇log π_t + kl_coeff * ∇log π_t +``` + +In synchronous RL with single gradient step, `ratio ≈ 1` and `KL_old ≈ KL_current`, so these are similar. 
They diverge with: +- Multiple epochs per batch (PPO-style) +- Async RL with stale rollouts +- Large policy updates + +## Recommended Patterns + +### Pattern 1: Pure RL (task rewards only) +```python +algo = make_gspo(group_size=8) +``` + +### Pattern 2: GSPO + OPD hybrid (recommended for distillation) +```python +algo = make_gspo_opd(group_size=8, kl_coeff=1.0) +``` + +### Pattern 3: Independent auxiliary loss +```python +algo = RLAlgorithm( + credit_assigner=GroupNormalizedReturn(group_size=8), + loss=CompositeLoss(terms=[ + LossTerm(name="rl", loss=ClippedSurrogateLoss(...), weight=1.0), + LossTerm(name="lm", loss=LanguageModelingLoss(...), weight=0.1), + ]), +) +``` + +## Summary + +| Scenario | Level | Implementation | +|----------|-------|----------------| +| Pure RL | - | `make_gspo()` | +| RL + teacher distillation | 2 (Advantage) | `make_gspo_opd()` | +| RL + unrelated auxiliary | 3 (Loss) | `CompositeLoss` | +| Intrinsic rewards | 1 (Reward) | Agent scorers | diff --git a/examples/opd/README.md b/examples/opd/README.md index 4d5cfe6..dadb01b 100644 --- a/examples/opd/README.md +++ b/examples/opd/README.md @@ -97,14 +97,22 @@ Rollouts are written to `opd_rollouts.jsonl`. ## How GSPO + OPD works +This uses "Level 2: Advantage Modification" via CreditModifier (see `docs/composition.md`): + 1. **Student samples**: The student model generates completions for GSM8K problems 2. **Environment rewards**: Each completion is graded for correctness (sparse reward) 3. **Teacher scores**: The teacher model computes per-token logprobs on the student's samples -4. **Composite loss**: Training uses two objectives: - - **GSPO**: Policy gradient with group-normalized advantages from task rewards - - **Reverse KL**: Minimizes `KL(student || teacher) = log π_student - log π_teacher` - -This gives the student task-specific learning from environment feedback while also pushing it to match the teacher's token distribution. +4. 
**Credit assignment**: GroupNormalizedReturn computes task-based advantages +5. **Credit modification**: KLCreditModifier adds KL penalty to advantages: + ``` + A_t = task_advantage + (-kl_coeff * (actor_logp_t - teacher_logp_t)) + ``` +6. **Policy gradient**: ClippedSurrogateLoss with modified advantages + +Key benefits of this approach (vs CompositeLoss): +- KL goes through importance sampling (multiplied by ratio like task rewards) +- KL uses old policy logprobs from rollout time (not current policy) +- All signals interact through the same loss function ## Tips diff --git a/examples/opd/train_opd_gsm8k.py b/examples/opd/train_opd_gsm8k.py index 59fb589..685e0c6 100644 --- a/examples/opd/train_opd_gsm8k.py +++ b/examples/opd/train_opd_gsm8k.py @@ -5,13 +5,17 @@ - GSPO (Group-Sorted Policy Optimization): Task rewards from GSM8K correctness - OPD (On-Policy Distillation): Dense per-token supervision from teacher -The hybrid approach uses a composite loss: - 1. ClippedSurrogateLoss (GSPO): Policy gradient with group-normalized advantages - 2. ReverseKLLoss (OPD): KL(student || teacher) = log π_student - log π_teacher +The hybrid uses "Level 2: Advantage Modification" via CreditModifier: + 1. GroupNormalizedReturn computes task-based advantages from correctness rewards + 2. KLCreditModifier adds negative KL to advantages: A_t += -coeff * (actor_logp - teacher_logp) + 3. ClippedSurrogateLoss computes policy gradient with modified advantages -This gives you both: - - Task-specific learning from environment rewards (sparse but grounded) - - Distribution matching from teacher (dense per-token feedback) +Key benefits of this approach (vs CompositeLoss): + - KL goes through importance sampling (like task rewards) + - KL uses old policy logprobs from rollout time + - All signals interact through the same loss function + +See docs/composition.md for the full composition level documentation. The key insight: teacher logprobs are an *intrinsic scorer* attached to the Agent. 
The scorer runs during Agent.act() and scores flow through to training. @@ -58,13 +62,10 @@ RolloutRequest, EnvSpec, ProtocolSpec, - RLAlgorithm, + make_gspo_opd, ) from ludic.training import Reducer, RichLiveLogger, PrintLogger, TeeLogger, WandbLogger, default_reducers from ludic.training.scoring import make_vllm_teacher_scorer -from ludic.training.credit_assignment import GroupNormalizedReturn -from ludic.training.loss import ClippedSurrogateLoss, ReverseKLLoss, CompositeLoss, LossTerm -from ludic.training.algorithm import validate_actor_logps from environments.gsm8k import GSM8KEnv @@ -225,38 +226,13 @@ def protocol_factory(): protocol_registry = {"single_agent": protocol_factory} - # Algorithm: GSPO (task rewards) + OPD (teacher KL) hybrid - # Credit assignment from GSPO: group-normalized returns - credit_assigner = GroupNormalizedReturn( + # Algorithm: GSPO + OPD hybrid (Advantage Composition) + # - GSPO: Task rewards with group-normalized advantages + # - OPD: KL penalty added to advantages via KLCreditModifier + # See docs/composition.md for composition level documentation + algo = make_gspo_opd( group_size=args.group_size, - normalize_adv=True, - positive_only=False, - ) - # Composite loss: PPO-style clipped surrogate + reverse KL - loss = CompositeLoss(terms=[ - LossTerm( - name="gspo", - loss=ClippedSurrogateLoss( - clip_eps_low=3e-4, - clip_eps_high=4e-4, - length_normalize=True, - ), - weight=1.0, - ), - LossTerm( - name="kl", - loss=ReverseKLLoss( - coeff=1.0, - length_normalize=True, - ), - weight=args.kl_coeff, - ), - ]) - algo = RLAlgorithm( - name="gspo_opd", - credit_assigner=credit_assigner, - loss=loss, - preprocess=validate_actor_logps, + kl_coeff=args.kl_coeff, ) # Engine + batch source @@ -340,9 +316,8 @@ def protocol_factory(): # Logger keys logger_keys = [ "train/loss", - "train/gspo/loss", - "train/kl/loss", - "train/kl/reverse_kl_mean", + "train/kl/kl_mean", # KLCreditModifier: mean reverse KL per token + "train/kl/kl_penalty_mean", # 
KLCreditModifier: mean penalty added to advantages "train/avg_total_reward", "train/correct_rate", "train/parse_err_rate", diff --git a/src/ludic/training/__init__.py b/src/ludic/training/__init__.py index a47e3f0..d19ef0f 100644 --- a/src/ludic/training/__init__.py +++ b/src/ludic/training/__init__.py @@ -27,6 +27,7 @@ make_cispo, make_sft, make_opd, + make_gspo_opd, ) from .credit_assignment import ( GroupNormalizedReturn, @@ -34,6 +35,8 @@ PerStepReward, EpisodicReturn, ConstantCredit, + CreditModifier, + KLCreditModifier, ) from .loss import ( Loss, @@ -87,12 +90,16 @@ "make_cispo", "make_sft", "make_opd", + "make_gspo_opd", # Credit assignment "GroupNormalizedReturn", "MonteCarloReturn", "PerStepReward", "EpisodicReturn", "ConstantCredit", + # Credit modifiers (Level 2: Advantage Modification) + "CreditModifier", + "KLCreditModifier", # Losses "Loss", "ReinforceLoss", diff --git a/src/ludic/training/algorithm.py b/src/ludic/training/algorithm.py index f82f040..d2f9d0f 100644 --- a/src/ludic/training/algorithm.py +++ b/src/ludic/training/algorithm.py @@ -1,7 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any, Dict, Mapping, Optional, Protocol +from dataclasses import dataclass, field +from typing import Any, Dict, List, Mapping, Optional, Protocol from jaxtyping import Float from torch import nn, Tensor @@ -17,7 +17,13 @@ MaskedCausalLMCrossEntropyLoss, ReverseKLLoss, ) -from ludic.training.credit_assignment import MonteCarloReturn, GroupNormalizedReturn, ConstantCredit +from ludic.training.credit_assignment import ( + MonteCarloReturn, + GroupNormalizedReturn, + ConstantCredit, + CreditModifier, + KLCreditModifier, +) Batch = Mapping[str, Tensor] @@ -30,18 +36,37 @@ def __call__(self, saw_batch: SAWBatch) -> SAWBatch: ... @dataclass class RLAlgorithm: """ - Full RL algorithm = credit assignment + loss. - - - credit_assigner: maps Rollouts -> per-step scalar credits - (e.g. 
discounted returns / advantages) - - loss: consumes a collated batch (built from SAWBatch) and produces - a scalar loss and stats. - - name: identifier for logging / checkpoints + Full RL algorithm = credit assignment + credit modifiers + loss. + + Pipeline: + Rollouts → CreditAssigner → SAWBatch → Collator → Batch + ↓ + CreditModifiers + ↓ + Modified Batch + ↓ + Loss + + Components: + - credit_assigner: Rollouts → per-step scalar credits (advantages) + - credit_modifiers: Batch → Modified Batch (add per-token signals to advantages) + - loss: Batch + Logits → scalar loss + + CreditModifiers (Level 2: Advantage Modification): + Modify advantages AFTER collation, BEFORE loss. Signals added here: + - Go through importance sampling (multiplied by ratio) + - Use rollout-time values (old policy) + - Interact with task rewards through the same loss + + Example: KLCreditModifier adds teacher KL penalty to advantages. + + See docs/composition.md for the full composition level documentation. """ name: str credit_assigner: CreditAssigner loss: Loss + credit_modifiers: List[CreditModifier] = field(default_factory=list) preprocess: Optional[PreprocessFn] = None def compute_loss( @@ -50,8 +75,22 @@ def compute_loss( batch: Batch, ) -> tuple[Tensor, Dict[str, Any]]: """ - Runs the forward pass once and delegates to the Loss object. + Apply credit modifiers, run forward pass, compute loss. 
+ + Returns: + Tuple of (loss, stats_dict) where stats_dict includes: + - Modifier metrics namespaced as "{modifier.name}/{metric}" + - Loss metrics from self.loss.compute() """ + all_stats: Dict[str, Any] = {} + + # --- Apply credit modifiers (Level 2: Advantage Modification) --- + for modifier in self.credit_modifiers: + batch, modifier_stats = modifier.modify(batch) + # Namespace modifier metrics + for key, value in modifier_stats.items(): + all_stats[f"{modifier.name}/{key}"] = value + # --- Run the forward pass --- input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] @@ -61,8 +100,11 @@ def compute_loss( ) logits: Logits = outputs.logits - # Pass the resulting logits to the loss function - return self.loss.compute(logits, batch) + # --- Compute loss --- + loss, loss_stats = self.loss.compute(logits, batch) + all_stats.update(loss_stats) + + return loss, all_stats # --------------------------------------------------------------------------- @@ -516,6 +558,10 @@ def make_opd( ``` Reference: https://thinkingmachines.ai/blog/on-policy-distillation + + Note: + This uses Loss Composition (Level 3). For OPD where KL goes through + importance sampling (Level 2: Advantage Modification), use make_gspo_opd() instead. """ credit_assigner: CreditAssigner = ConstantCredit(value=1.0) loss: Loss = ReverseKLLoss(coeff=kl_coeff, length_normalize=length_normalize) @@ -525,3 +571,100 @@ def make_opd( credit_assigner=credit_assigner, loss=loss, ) + + +def make_gspo_opd( + *, + group_size: int, + kl_coeff: float = 1.0, + group_normalize_adv: bool = True, + positive_only: bool = False, + clip_eps_low: float = 3e-4, + clip_eps_high: float = 4e-4, + length_normalize: bool = True, + ratio_clip: Optional[float] = None, + drop_zero_weight: bool = False, + drop_zero_weight_eps: float = 1e-4, + name: str = "gspo_opd", +) -> RLAlgorithm: + """ + GSPO + OPD hybrid using advantage composition. 
+ + Combines: + - GSPO: Task rewards with group-normalized advantages + - OPD: Teacher KL penalty added to advantages (Level 2: Advantage Modification) + + KL is computed at rollout time and added to advantages BEFORE the loss. + This means: + - KL goes through importance sampling (multiplied by ratio) + - KL uses old policy logprobs (from rollout), not current policy + - All signals interact through the same loss function + + Pipeline: + 1. GroupNormalizedReturn computes task-based advantages + 2. KLCreditModifier adds negative KL to advantages + 3. ClippedSurrogateLoss computes policy gradient with modified advantages + + Args: + group_size: Number of rollouts per group for advantage normalization. + kl_coeff: Coefficient for KL penalty. Higher = stronger teacher matching. + group_normalize_adv: Normalize advantages within each group. + positive_only: Clip negative advantages to zero. + clip_eps_low: Lower PPO clipping epsilon. + clip_eps_high: Upper PPO clipping epsilon. + length_normalize: Normalize loss by number of action tokens. + ratio_clip: Optional upper bound for ratio truncation. + drop_zero_weight: Drop zero-advantage samples before collation. + drop_zero_weight_eps: Epsilon for zero-weight detection. + name: Algorithm name for logging. + + Prerequisites: + - Rollouts must return actor logprobs (ReturnSpec.for_rl()) + - Agent must have a teacher scorer (make_vllm_teacher_scorer()) + - Rollouts must have group_id for GSPO advantage normalization + + Example: + ```python + from ludic.training import make_gspo_opd + from ludic.training.scoring import make_vllm_teacher_scorer + + teacher_scorer = make_vllm_teacher_scorer( + base_url="http://localhost:8001", + model="Qwen/Qwen3-32B", + ) + + agent = Agent(client=client, ..., scorers=[teacher_scorer]) + algo = make_gspo_opd(group_size=8, kl_coeff=1.0) + trainer = Trainer(model=model, algo=algo, ...) 
+ ``` + + Reference: https://thinkingmachines.ai/blog/on-policy-distillation + """ + credit_assigner: CreditAssigner = GroupNormalizedReturn( + group_size=group_size, + normalize_adv=group_normalize_adv, + positive_only=positive_only, + ) + + credit_modifiers = [KLCreditModifier(coeff=kl_coeff)] + + loss: Loss = ClippedSurrogateLoss( + clip_eps_low=clip_eps_low, + clip_eps_high=clip_eps_high, + length_normalize=length_normalize, + ratio_clip=ratio_clip, + ) + + preprocess_fns = [] + if drop_zero_weight: + preprocess_fns.append(lambda batch: drop_zero_weight_samples(batch, eps=drop_zero_weight_eps)) + preprocess_fns.append(validate_actor_logps) + preprocess = compose_preprocess(*preprocess_fns) + + return RLAlgorithm( + name=name, + credit_assigner=credit_assigner, + credit_modifiers=credit_modifiers, + loss=loss, + preprocess=preprocess, + ) diff --git a/src/ludic/training/credit_assignment.py b/src/ludic/training/credit_assignment.py index c17c351..69056d0 100644 --- a/src/ludic/training/credit_assignment.py +++ b/src/ludic/training/credit_assignment.py @@ -2,13 +2,150 @@ from collections import defaultdict from dataclasses import dataclass -from typing import Dict, List +from typing import Any, Dict, List, Mapping, Protocol, Tuple import torch +from torch import Tensor + from ludic.types import Rollout from ludic.training.types import RolloutStepKey +Batch = Mapping[str, Tensor] + + +# --------------------------------------------------------------------------- +# Credit Modifiers (Level 2: Advantage Modification) +# --------------------------------------------------------------------------- +# +# CreditModifiers operate on collated batches (after credit assignment) to +# modify the per-token advantages/weights before loss computation. +# +# Key use case: OPD where KL penalty is added to advantages, causing it to +# go through importance sampling like task rewards. +# +# See docs/composition.md for the full composition level documentation. 
+# --------------------------------------------------------------------------- + + +class CreditModifier(Protocol): + """ + Modifies batch advantages after collation, before loss computation. + + This is "Level 2: Advantage Modification" - adding per-token signals + (like KL penalty) to trajectory-level advantages from credit assignment. + + Unlike CompositeLoss (Level 3: Loss Composition), signals added here: + - Go through importance sampling (multiplied by ratio) + - Use rollout-time (old policy) values, not current policy + - Interact with task rewards through the same loss function + + Example: + >>> modifier = KLCreditModifier(coeff=1.0) + >>> batch, metrics = modifier.modify(batch) + >>> # batch["weight"] now includes KL penalty + """ + + name: str + + def modify(self, batch: Batch) -> Tuple[Batch, Dict[str, Any]]: + """ + Modify batch advantages and return metrics. + + Args: + batch: Collated batch with at least: + - "weight": [B, T] per-token advantages from credit assignment + - Other fields depend on the modifier (e.g., "actor_logps", "teacher_logps") + + Returns: + Tuple of (modified_batch, metrics_dict). + The batch should be modified in-place for efficiency. + Metrics are namespaced under self.name by the caller. + """ + ... + + +@dataclass +class KLCreditModifier: + """ + Add negative reverse KL to advantages for on-policy distillation. + + Implements KL penalty for OPD: + kl_advantage_t = -coeff * (actor_logps_t - teacher_logps_t) + + The KL is computed at rollout time (old policy), so it goes through + importance sampling when used with ratio-based losses like PPO/GSPO. + + Args: + coeff: Coefficient for KL penalty. Higher = stronger teacher matching. + name: Modifier name for logging. Metrics appear as "{name}/kl_mean", etc. + + Requires batch to have: + - "actor_logps": [B, T] old policy logprobs from rollout + - "teacher_logps": [B, T] teacher logprobs + + Example: + >>> algo = RLAlgorithm( + ... 
credit_assigner=GroupNormalizedReturn(group_size=8), + ... credit_modifiers=[KLCreditModifier(coeff=1.0)], + ... loss=ClippedSurrogateLoss(...), + ... ) + """ + + coeff: float = 1.0 + name: str = "kl" + + def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: + if "actor_logps" not in batch: + raise KeyError( + "KLCreditModifier requires batch['actor_logps']. " + "Ensure rollouts return actor logprobs (ReturnSpec.for_rl())." + ) + if "teacher_logps" not in batch: + raise KeyError( + "KLCreditModifier requires batch['teacher_logps']. " + "Ensure agent has a teacher scorer (e.g., make_vllm_teacher_scorer())." + ) + + actor_logps = batch["actor_logps"] # [B, T] + teacher_logps = batch["teacher_logps"] # [B, T] + action_mask = batch["action_mask"] # [B, T] + weight = batch["weight"] # [B, T] + + # Reverse KL: log π_student - log π_teacher + # We want to minimize this, so we add NEGATIVE KL to advantages + reverse_kl = actor_logps - teacher_logps # [B, T] + kl_penalty = -self.coeff * reverse_kl # [B, T] + + # Add to weight (advantage), masked to action tokens only + # Note: action_mask should already be applied to weight, but we + # apply it to kl_penalty too for safety + modified_weight = weight + kl_penalty * action_mask.float() + + # Create modified batch (shallow copy with updated weight) + modified_batch = dict(batch) + modified_batch["weight"] = modified_weight + + # Compute metrics (masked to action tokens) + mask_sum = action_mask.sum() + if mask_sum > 0: + masked_kl = reverse_kl * action_mask.float() + kl_mean = masked_kl.sum() / mask_sum + kl_std = ((masked_kl - kl_mean * action_mask.float()) ** 2).sum() / mask_sum + kl_std = kl_std.sqrt() + else: + kl_mean = reverse_kl.new_zeros(()) + kl_std = reverse_kl.new_zeros(()) + + metrics = { + "kl_mean": kl_mean.detach(), + "kl_std": kl_std.detach(), + "kl_penalty_mean": (kl_penalty * action_mask.float()).sum().detach() / max(mask_sum, 1), + } + + return modified_batch, metrics + + # ---- Credit 
Assigners ---- @dataclass diff --git a/tests/test_credit_assignment.py b/tests/test_credit_assignment.py index f02b61b..d76f24c 100644 --- a/tests/test_credit_assignment.py +++ b/tests/test_credit_assignment.py @@ -1,5 +1,6 @@ import pytest import math +import torch from ludic.types import Rollout, EnvironmentStep, TokenTrace from ludic.training.credit_assignment import ( @@ -7,6 +8,7 @@ EpisodicReturn, PerStepReward, GroupNormalizedReturn, + KLCreditModifier, ) # ---- Helper to build a simple rollout ---- @@ -243,3 +245,162 @@ def test_group_normalized_return_invalid_group_size(): with pytest.raises(ValueError, match="group_size must be positive"): GroupNormalizedReturn(group_size=-1) + + +# ---- KLCreditModifier Tests ---- + +def _make_batch( + *, + weight: torch.Tensor, + actor_logps: torch.Tensor, + teacher_logps: torch.Tensor, + action_mask: torch.Tensor, +) -> dict: + """Helper to create a batch dict for KLCreditModifier tests.""" + return { + "weight": weight, + "actor_logps": actor_logps, + "teacher_logps": teacher_logps, + "action_mask": action_mask, + } + + +def test_kl_credit_modifier_basic(): + """Test basic KL penalty addition to advantages.""" + # Batch of 2 samples, 4 tokens each + # actor_logps - teacher_logps = reverse KL + # KL penalty = -coeff * reverse_kl + weight = torch.zeros(2, 4) + actor_logps = torch.tensor([ + [-1.0, -2.0, -1.5, -0.5], # sample 0 + [-2.0, -1.0, -3.0, -1.0], # sample 1 + ]) + teacher_logps = torch.tensor([ + [-1.5, -1.5, -1.5, -1.5], # sample 0: teacher + [-1.5, -1.5, -1.5, -1.5], # sample 1: teacher + ]) + action_mask = torch.tensor([ + [0, 1, 1, 1], # sample 0: first token is prompt + [0, 0, 1, 1], # sample 1: first two tokens are prompt + ]) + + batch = _make_batch( + weight=weight, + actor_logps=actor_logps, + teacher_logps=teacher_logps, + action_mask=action_mask, + ) + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, metrics = modifier.modify(batch) + + # reverse_kl = actor - teacher + # sample 0: [-1.0 - 
(-1.5), -2.0 - (-1.5), -1.5 - (-1.5), -0.5 - (-1.5)] + # = [0.5, -0.5, 0.0, 1.0] + # sample 1: [-2.0 - (-1.5), -1.0 - (-1.5), -3.0 - (-1.5), -1.0 - (-1.5)] + # = [-0.5, 0.5, -1.5, 0.5] + # kl_penalty = -1.0 * reverse_kl (masked) + # sample 0: [0, 0.5, 0.0, -1.0] (first token masked) + # sample 1: [0, 0, 1.5, -0.5] (first two masked) + + expected_weight = torch.tensor([ + [0.0, 0.5, 0.0, -1.0], + [0.0, 0.0, 1.5, -0.5], + ]) + + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + # Check metrics + assert "kl_mean" in metrics + assert "kl_std" in metrics + assert "kl_penalty_mean" in metrics + + +def test_kl_credit_modifier_with_existing_advantage(): + """Test that KL penalty is added to existing advantages.""" + # Start with non-zero advantages + weight = torch.tensor([ + [0.0, 1.0, 1.0, 1.0], # sample 0 + [0.0, 0.0, 2.0, 2.0], # sample 1 + ]) + actor_logps = torch.tensor([ + [-1.0, -1.0, -1.0, -1.0], + [-1.0, -1.0, -1.0, -1.0], + ]) + teacher_logps = torch.tensor([ + [-2.0, -2.0, -2.0, -2.0], # actor closer to 0 = higher prob + [-2.0, -2.0, -2.0, -2.0], + ]) + action_mask = torch.tensor([ + [0, 1, 1, 1], + [0, 0, 1, 1], + ]) + + batch = _make_batch( + weight=weight, + actor_logps=actor_logps, + teacher_logps=teacher_logps, + action_mask=action_mask, + ) + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, _ = modifier.modify(batch) + + # reverse_kl = -1.0 - (-2.0) = 1.0 everywhere + # kl_penalty = -1.0 * 1.0 = -1.0 (penalty for being overconfident vs teacher) + # Modified weight = original + penalty (masked) + expected_weight = torch.tensor([ + [0.0, 1.0 - 1.0, 1.0 - 1.0, 1.0 - 1.0], # [0, 0, 0, 0] + [0.0, 0.0, 2.0 - 1.0, 2.0 - 1.0], # [0, 0, 1, 1] + ]) + + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + +def test_kl_credit_modifier_coeff(): + """Test that coeff scales the KL penalty.""" + weight = torch.zeros(1, 3) + actor_logps = torch.tensor([[-1.0, -1.0, -1.0]]) + teacher_logps = 
torch.tensor([[-2.0, -2.0, -2.0]]) + action_mask = torch.tensor([[1, 1, 1]]) + + batch = _make_batch( + weight=weight, + actor_logps=actor_logps, + teacher_logps=teacher_logps, + action_mask=action_mask, + ) + + # With coeff=2.0, penalty should be doubled + modifier = KLCreditModifier(coeff=2.0) + modified_batch, _ = modifier.modify(batch) + + # reverse_kl = 1.0, penalty = -2.0 * 1.0 = -2.0 + expected_weight = torch.tensor([[-2.0, -2.0, -2.0]]) + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + +def test_kl_credit_modifier_missing_actor_logps(): + """Test that missing actor_logps raises KeyError.""" + batch = { + "weight": torch.zeros(1, 3), + "teacher_logps": torch.zeros(1, 3), + "action_mask": torch.ones(1, 3), + } + + modifier = KLCreditModifier() + with pytest.raises(KeyError, match="actor_logps"): + modifier.modify(batch) + + +def test_kl_credit_modifier_missing_teacher_logps(): + """Test that missing teacher_logps raises KeyError.""" + batch = { + "weight": torch.zeros(1, 3), + "actor_logps": torch.zeros(1, 3), + "action_mask": torch.ones(1, 3), + } + + modifier = KLCreditModifier() + with pytest.raises(KeyError, match="teacher_logps"): + modifier.modify(batch) From d939fdaf22b0ca80185cb155cfc0643a487ec49b Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 15:45:13 +0100 Subject: [PATCH 12/17] fix bugs --- src/ludic/training/credit_assignment.py | 11 ++++---- tests/test_credit_assignment.py | 35 +++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/src/ludic/training/credit_assignment.py b/src/ludic/training/credit_assignment.py index 69056d0..85242e1 100644 --- a/src/ludic/training/credit_assignment.py +++ b/src/ludic/training/credit_assignment.py @@ -117,10 +117,11 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: reverse_kl = actor_logps - teacher_logps # [B, T] kl_penalty = -self.coeff * reverse_kl # [B, T] - # Add to weight (advantage), masked to 
action tokens only - # Note: action_mask should already be applied to weight, but we - # apply it to kl_penalty too for safety - modified_weight = weight + kl_penalty * action_mask.float() + # Add KL penalty to weight (advantage), then mask to action tokens only. + # We apply action_mask to the entire sum to ensure prompt tokens have + # zero weight, regardless of how the upstream credit assigner populated + # the weight tensor. This prevents prompt-length-dependent loss scaling. + modified_weight = (weight + kl_penalty) * action_mask.float() # Create modified batch (shallow copy with updated weight) modified_batch = dict(batch) @@ -140,7 +141,7 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: metrics = { "kl_mean": kl_mean.detach(), "kl_std": kl_std.detach(), - "kl_penalty_mean": (kl_penalty * action_mask.float()).sum().detach() / max(mask_sum, 1), + "kl_penalty_mean": (kl_penalty * action_mask.float()).sum().detach() / mask_sum.clamp(min=1), } return modified_batch, metrics diff --git a/tests/test_credit_assignment.py b/tests/test_credit_assignment.py index d76f24c..913bfdc 100644 --- a/tests/test_credit_assignment.py +++ b/tests/test_credit_assignment.py @@ -404,3 +404,38 @@ def test_kl_credit_modifier_missing_teacher_logps(): modifier = KLCreditModifier() with pytest.raises(KeyError, match="teacher_logps"): modifier.modify(batch) + + +def test_kl_credit_modifier_zeroes_prompt_tokens(): + """ + Test that prompt tokens (action_mask=0) always have zero weight, + even if the input weight tensor had non-zero values there. + + This prevents prompt-length-dependent loss scaling. 
+ """ + # Input weight has non-zero values on ALL tokens (as might happen + # if trajectory-level advantages were broadcast without masking) + weight = torch.tensor([ + [5.0, 5.0, 5.0, 5.0], # advantage of 5.0 broadcast to all tokens + ]) + actor_logps = torch.tensor([[-1.0, -1.0, -1.0, -1.0]]) + teacher_logps = torch.tensor([[-1.0, -1.0, -1.0, -1.0]]) # identical, so KL=0 + action_mask = torch.tensor([ + [0, 0, 1, 1], # first two tokens are prompt + ]) + + batch = _make_batch( + weight=weight, + actor_logps=actor_logps, + teacher_logps=teacher_logps, + action_mask=action_mask, + ) + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, _ = modifier.modify(batch) + + # Key assertion: prompt tokens (positions 0, 1) must have zero weight, + # even though input weight was 5.0 there + # Action tokens (positions 2, 3) should have weight=5.0 (advantage + 0 KL penalty) + expected_weight = torch.tensor([[0.0, 0.0, 5.0, 5.0]]) + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) From 06aa735e15ba89f2436e6734a7170e4ac53b54ff Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 16:05:55 +0100 Subject: [PATCH 13/17] resolve shape mismatch --- src/ludic/training/credit_assignment.py | 25 ++++++- tests/test_credit_assignment.py | 98 +++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 2 deletions(-) diff --git a/src/ludic/training/credit_assignment.py b/src/ludic/training/credit_assignment.py index 85242e1..1c6ade7 100644 --- a/src/ludic/training/credit_assignment.py +++ b/src/ludic/training/credit_assignment.py @@ -110,7 +110,28 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: actor_logps = batch["actor_logps"] # [B, T] teacher_logps = batch["teacher_logps"] # [B, T] action_mask = batch["action_mask"] # [B, T] - weight = batch["weight"] # [B, T] + weight = batch["weight"] # [B], [B, C], or [B, T] + + B, T = action_mask.shape + + # Handle different weight shapes: + # - [B]: turn-level weight (one per 
sample) -> broadcast to [B, 1] + # - [B, C]: completion-only weight (C < T) -> expand to [B, T] + # - [B, T]: already token-level -> use as-is + if weight.dim() == 1: + # Turn-level: broadcast to all tokens + weight = weight.unsqueeze(-1) # [B] -> [B, 1] for broadcasting + elif weight.shape[-1] != T: + # Completion-only: expand to full sequence using action_mask positions. + # weight[b, :n_actions] goes to positions where action_mask[b] == 1. + # We handle each sample separately since they may have different + # numbers of action tokens (variable completion lengths). + expanded_weight = torch.zeros(B, T, device=weight.device, dtype=weight.dtype) + for b in range(B): + action_indices = action_mask[b].nonzero(as_tuple=True)[0] + n_actions = len(action_indices) + expanded_weight[b, action_indices] = weight[b, :n_actions] + weight = expanded_weight # Reverse KL: log π_student - log π_teacher # We want to minimize this, so we add NEGATIVE KL to advantages @@ -121,7 +142,7 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: # We apply action_mask to the entire sum to ensure prompt tokens have # zero weight, regardless of how the upstream credit assigner populated # the weight tensor. This prevents prompt-length-dependent loss scaling. 
- modified_weight = (weight + kl_penalty) * action_mask.float() + modified_weight = (weight + kl_penalty) * action_mask.float() # [B, T] # Create modified batch (shallow copy with updated weight) modified_batch = dict(batch) diff --git a/tests/test_credit_assignment.py b/tests/test_credit_assignment.py index 913bfdc..d80b9c5 100644 --- a/tests/test_credit_assignment.py +++ b/tests/test_credit_assignment.py @@ -439,3 +439,101 @@ def test_kl_credit_modifier_zeroes_prompt_tokens(): # Action tokens (positions 2, 3) should have weight=5.0 (advantage + 0 KL penalty) expected_weight = torch.tensor([[0.0, 0.0, 5.0, 5.0]]) assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + +def test_kl_credit_modifier_turn_level_weight(): + """ + Test that turn-level (1D) weights are broadcast to token-level. + + Credit assigners produce one weight per turn/step, but KL is per-token. + The modifier should broadcast the turn weight to all action tokens. + """ + # Turn-level weight: one value per sample (shape [B]) + weight = torch.tensor([2.0, -1.0]) # two samples with advantages 2.0 and -1.0 + + # Token-level tensors: shape [B, T] + actor_logps = torch.tensor([ + [-1.0, -1.0, -1.0, -1.0], + [-2.0, -2.0, -2.0, -2.0], + ]) + teacher_logps = torch.tensor([ + [-1.5, -1.5, -1.5, -1.5], # KL = -1.0 - (-1.5) = 0.5 + [-1.5, -1.5, -1.5, -1.5], # KL = -2.0 - (-1.5) = -0.5 + ]) + action_mask = torch.tensor([ + [0, 1, 1, 1], # sample 0: 3 action tokens + [0, 0, 1, 1], # sample 1: 2 action tokens + ]) + + batch = { + "weight": weight, # [B] - turn-level + "actor_logps": actor_logps, + "teacher_logps": teacher_logps, + "action_mask": action_mask, + } + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, _ = modifier.modify(batch) + + # Sample 0: turn_weight=2.0, kl_penalty=-0.5 per token + # modified = (2.0 + (-0.5)) * mask = 1.5 on action tokens + # Sample 1: turn_weight=-1.0, kl_penalty=0.5 per token + # modified = (-1.0 + 0.5) * mask = -0.5 on action tokens + 
expected_weight = torch.tensor([ + [0.0, 1.5, 1.5, 1.5], + [0.0, 0.0, -0.5, -0.5], + ]) + + assert modified_batch["weight"].shape == (2, 4) # now token-level + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + +def test_kl_credit_modifier_completion_only_weight(): + """ + Test that completion-only weights [B, C] are expanded to full sequence [B, T]. + + When weight is stored compactly for just completion tokens (C < T), + the modifier expands it to full sequence length using action_mask positions. + """ + # Completion-only weight: [B, C] where C is number of completion tokens + # Sample 0 has 3 completion tokens, sample 1 has 2 + # But they're padded to same length C=3 + weight = torch.tensor([ + [1.0, 2.0, 3.0], # sample 0: weights for 3 action tokens + [4.0, 5.0, 0.0], # sample 1: weights for 2 action tokens (padded) + ]) + + # Full sequence tensors: [B, T] where T=6 (prompt + completion) + actor_logps = torch.tensor([ + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + ]) + teacher_logps = torch.tensor([ + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], # KL = 0 + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + ]) + action_mask = torch.tensor([ + [0, 0, 0, 1, 1, 1], # sample 0: 3 prompt, 3 completion + [0, 0, 0, 0, 1, 1], # sample 1: 4 prompt, 2 completion + ]) + + batch = { + "weight": weight, # [B, C=3] - completion-only + "actor_logps": actor_logps, + "teacher_logps": teacher_logps, + "action_mask": action_mask, + } + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, _ = modifier.modify(batch) + + # With KL=0, modified_weight = weight * action_mask + # Sample 0: weights [1, 2, 3] placed at positions [3, 4, 5] + # Sample 1: weights [4, 5] placed at positions [4, 5] (the 0.0 padding is ignored) + expected_weight = torch.tensor([ + [0.0, 0.0, 0.0, 1.0, 2.0, 3.0], + [0.0, 0.0, 0.0, 0.0, 4.0, 5.0], + ]) + + assert modified_batch["weight"].shape == (2, 6) + assert torch.allclose(modified_batch["weight"], 
expected_weight, atol=1e-6) From 592a1ebc94a4b514cb5dd53ed350f460638128c3 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 16:14:03 +0100 Subject: [PATCH 14/17] fix shape mismatch again --- src/ludic/training/credit_assignment.py | 73 +++++++++++++++---------- tests/test_credit_assignment.py | 42 +++++++------- 2 files changed, 64 insertions(+), 51 deletions(-) diff --git a/src/ludic/training/credit_assignment.py b/src/ludic/training/credit_assignment.py index 1c6ade7..64ca99c 100644 --- a/src/ludic/training/credit_assignment.py +++ b/src/ludic/training/credit_assignment.py @@ -107,42 +107,55 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: "Ensure agent has a teacher scorer (e.g., make_vllm_teacher_scorer())." ) - actor_logps = batch["actor_logps"] # [B, T] - teacher_logps = batch["teacher_logps"] # [B, T] - action_mask = batch["action_mask"] # [B, T] + actor_logps = batch["actor_logps"] # [B, T] full sequence + teacher_logps = batch["teacher_logps"] # [B, T] full sequence + action_mask = batch["action_mask"] # [B, T] full sequence weight = batch["weight"] # [B], [B, C], or [B, T] B, T = action_mask.shape - # Handle different weight shapes: - # - [B]: turn-level weight (one per sample) -> broadcast to [B, 1] - # - [B, C]: completion-only weight (C < T) -> expand to [B, T] - # - [B, T]: already token-level -> use as-is - if weight.dim() == 1: - # Turn-level: broadcast to all tokens - weight = weight.unsqueeze(-1) # [B] -> [B, 1] for broadcasting - elif weight.shape[-1] != T: - # Completion-only: expand to full sequence using action_mask positions. - # weight[b, :n_actions] goes to positions where action_mask[b] == 1. - # We handle each sample separately since they may have different - # numbers of action tokens (variable completion lengths). 
- expanded_weight = torch.zeros(B, T, device=weight.device, dtype=weight.dtype) + # Reverse KL: log π_student - log π_teacher (full sequence) + # We want to minimize this, so we add NEGATIVE KL to advantages + reverse_kl = actor_logps - teacher_logps # [B, T] + kl_penalty_full = -self.coeff * reverse_kl # [B, T] + + # Handle different weight shapes, keeping output in same format as input. + # The loss function expects weight to match ratio shape, which may be + # completion-only [B, C] rather than full sequence [B, T]. + if weight.shape[-1] == T: + # Full sequence [B, T]: add KL directly, mask to action tokens. + modified_weight = (weight + kl_penalty_full) * action_mask.float() + else: + # Turn-level [B] or completion-only [B, C]: extract KL for action tokens. + # We need completion-only KL penalty to match the weight format. + # Determine max completion length from action_mask. + completion_lens = action_mask.sum(dim=-1).long() # [B] + max_completion_len = int(completion_lens.max().item()) + + # Extract KL penalty for action tokens only -> [B, max_completion_len] + kl_penalty_completion = torch.zeros( + B, max_completion_len, device=weight.device, dtype=weight.dtype + ) + # Also build completion-only mask for padding positions + completion_mask = torch.zeros( + B, max_completion_len, device=weight.device, dtype=weight.dtype + ) for b in range(B): action_indices = action_mask[b].nonzero(as_tuple=True)[0] n_actions = len(action_indices) - expanded_weight[b, action_indices] = weight[b, :n_actions] - weight = expanded_weight - - # Reverse KL: log π_student - log π_teacher - # We want to minimize this, so we add NEGATIVE KL to advantages - reverse_kl = actor_logps - teacher_logps # [B, T] - kl_penalty = -self.coeff * reverse_kl # [B, T] - - # Add KL penalty to weight (advantage), then mask to action tokens only. 
- # We apply action_mask to the entire sum to ensure prompt tokens have - # zero weight, regardless of how the upstream credit assigner populated - # the weight tensor. This prevents prompt-length-dependent loss scaling. - modified_weight = (weight + kl_penalty) * action_mask.float() # [B, T] + kl_penalty_completion[b, :n_actions] = kl_penalty_full[b, action_indices] + completion_mask[b, :n_actions] = 1.0 + + if weight.dim() == 1: + # Turn-level [B]: broadcast to completion-only, add per-token KL. + # Output is [B, max_completion_len] to match ratio shape. + # Zero out padding positions. + modified_weight = (weight.unsqueeze(-1) + kl_penalty_completion) * completion_mask + else: + # Completion-only [B, C]: add KL directly. + # C should equal max_completion_len (or be padded similarly). + C = weight.shape[-1] + modified_weight = weight + kl_penalty_completion[:, :C] # Create modified batch (shallow copy with updated weight) modified_batch = dict(batch) @@ -162,7 +175,7 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: metrics = { "kl_mean": kl_mean.detach(), "kl_std": kl_std.detach(), - "kl_penalty_mean": (kl_penalty * action_mask.float()).sum().detach() / mask_sum.clamp(min=1), + "kl_penalty_mean": (kl_penalty_full * action_mask.float()).sum().detach() / mask_sum.clamp(min=1), } return modified_batch, metrics diff --git a/tests/test_credit_assignment.py b/tests/test_credit_assignment.py index d80b9c5..0f51469 100644 --- a/tests/test_credit_assignment.py +++ b/tests/test_credit_assignment.py @@ -443,10 +443,11 @@ def test_kl_credit_modifier_zeroes_prompt_tokens(): def test_kl_credit_modifier_turn_level_weight(): """ - Test that turn-level (1D) weights are broadcast to token-level. + Test that turn-level (1D) weights are broadcast to completion-only format. Credit assigners produce one weight per turn/step, but KL is per-token. - The modifier should broadcast the turn weight to all action tokens. 
+ The modifier broadcasts the turn weight and adds per-token KL, outputting + completion-only [B, max_completion_len] to match ratio shape in loss. """ # Turn-level weight: one value per sample (shape [B]) weight = torch.tensor([2.0, -1.0]) # two samples with advantages 2.0 and -1.0 @@ -457,12 +458,12 @@ def test_kl_credit_modifier_turn_level_weight(): [-2.0, -2.0, -2.0, -2.0], ]) teacher_logps = torch.tensor([ - [-1.5, -1.5, -1.5, -1.5], # KL = -1.0 - (-1.5) = 0.5 - [-1.5, -1.5, -1.5, -1.5], # KL = -2.0 - (-1.5) = -0.5 + [-1.5, -1.5, -1.5, -1.5], # KL = -1.0 - (-1.5) = 0.5, penalty = -0.5 + [-1.5, -1.5, -1.5, -1.5], # KL = -2.0 - (-1.5) = -0.5, penalty = 0.5 ]) action_mask = torch.tensor([ - [0, 1, 1, 1], # sample 0: 3 action tokens - [0, 0, 1, 1], # sample 1: 2 action tokens + [0, 1, 1, 1], # sample 0: 3 action tokens (positions 1,2,3) + [0, 0, 1, 1], # sample 1: 2 action tokens (positions 2,3) ]) batch = { @@ -475,25 +476,25 @@ def test_kl_credit_modifier_turn_level_weight(): modifier = KLCreditModifier(coeff=1.0) modified_batch, _ = modifier.modify(batch) - # Sample 0: turn_weight=2.0, kl_penalty=-0.5 per token - # modified = (2.0 + (-0.5)) * mask = 1.5 on action tokens - # Sample 1: turn_weight=-1.0, kl_penalty=0.5 per token - # modified = (-1.0 + 0.5) * mask = -0.5 on action tokens + # Output is completion-only [B, max_completion_len=3] + # Sample 0: weight=2.0 + kl_penalty=[-0.5, -0.5, -0.5] -> [1.5, 1.5, 1.5] + # Sample 1: weight=-1.0 + kl_penalty=[0.5, 0.5, 0.0] -> [-0.5, -0.5, 0.0] + # (only 2 action tokens, so 3rd position is 0 padding) expected_weight = torch.tensor([ - [0.0, 1.5, 1.5, 1.5], - [0.0, 0.0, -0.5, -0.5], + [1.5, 1.5, 1.5], + [-0.5, -0.5, 0.0], # 3rd is padding (no action token there) ]) - assert modified_batch["weight"].shape == (2, 4) # now token-level + assert modified_batch["weight"].shape == (2, 3) # completion-only assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) def 
test_kl_credit_modifier_completion_only_weight(): """ - Test that completion-only weights [B, C] are expanded to full sequence [B, T]. + Test that completion-only weights [B, C] stay in completion-only format. When weight is stored compactly for just completion tokens (C < T), - the modifier expands it to full sequence length using action_mask positions. + the modifier extracts KL for those positions and adds it, keeping [B, C] shape. """ # Completion-only weight: [B, C] where C is number of completion tokens # Sample 0 has 3 completion tokens, sample 1 has 2 @@ -527,13 +528,12 @@ def test_kl_credit_modifier_completion_only_weight(): modifier = KLCreditModifier(coeff=1.0) modified_batch, _ = modifier.modify(batch) - # With KL=0, modified_weight = weight * action_mask - # Sample 0: weights [1, 2, 3] placed at positions [3, 4, 5] - # Sample 1: weights [4, 5] placed at positions [4, 5] (the 0.0 padding is ignored) + # With KL=0, modified_weight = weight (unchanged since kl_penalty is 0) + # Output stays in completion-only format [B, C] expected_weight = torch.tensor([ - [0.0, 0.0, 0.0, 1.0, 2.0, 3.0], - [0.0, 0.0, 0.0, 0.0, 4.0, 5.0], + [1.0, 2.0, 3.0], # sample 0: unchanged + [4.0, 5.0, 0.0], # sample 1: unchanged (padding preserved) ]) - assert modified_batch["weight"].shape == (2, 6) + assert modified_batch["weight"].shape == (2, 3) # stays completion-only assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) From 8fb93b979491c7a16bde69eecdc779632eddcc25 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 16:20:47 +0100 Subject: [PATCH 15/17] let Codex remove Opus' delusions --- src/ludic/training/credit_assignment.py | 39 +++++++++---------------- 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/src/ludic/training/credit_assignment.py b/src/ludic/training/credit_assignment.py index 64ca99c..31459ee 100644 --- a/src/ludic/training/credit_assignment.py +++ b/src/ludic/training/credit_assignment.py @@ -122,40 +122,29 @@ 
def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: # Handle different weight shapes, keeping output in same format as input. # The loss function expects weight to match ratio shape, which may be # completion-only [B, C] rather than full sequence [B, T]. - if weight.shape[-1] == T: + if weight.dim() == 1: + # Turn-level [B]: add completion KL summed over action tokens. + kl_penalty_scalar = (kl_penalty_full * action_mask.float()).sum(dim=-1) + modified_weight = weight + kl_penalty_scalar + elif weight.shape[-1] == T: # Full sequence [B, T]: add KL directly, mask to action tokens. modified_weight = (weight + kl_penalty_full) * action_mask.float() else: - # Turn-level [B] or completion-only [B, C]: extract KL for action tokens. - # We need completion-only KL penalty to match the weight format. - # Determine max completion length from action_mask. - completion_lens = action_mask.sum(dim=-1).long() # [B] - max_completion_len = int(completion_lens.max().item()) - - # Extract KL penalty for action tokens only -> [B, max_completion_len] + # Completion-only [B, C]: extract KL for action tokens and align to completion positions. + C = weight.shape[-1] kl_penalty_completion = torch.zeros( - B, max_completion_len, device=weight.device, dtype=weight.dtype + B, C, device=weight.device, dtype=weight.dtype ) - # Also build completion-only mask for padding positions completion_mask = torch.zeros( - B, max_completion_len, device=weight.device, dtype=weight.dtype + B, C, device=weight.device, dtype=weight.dtype ) for b in range(B): action_indices = action_mask[b].nonzero(as_tuple=True)[0] - n_actions = len(action_indices) - kl_penalty_completion[b, :n_actions] = kl_penalty_full[b, action_indices] - completion_mask[b, :n_actions] = 1.0 - - if weight.dim() == 1: - # Turn-level [B]: broadcast to completion-only, add per-token KL. - # Output is [B, max_completion_len] to match ratio shape. - # Zero out padding positions. 
- modified_weight = (weight.unsqueeze(-1) + kl_penalty_completion) * completion_mask - else: - # Completion-only [B, C]: add KL directly. - # C should equal max_completion_len (or be padded similarly). - C = weight.shape[-1] - modified_weight = weight + kl_penalty_completion[:, :C] + n_actions = min(action_indices.numel(), C) + if n_actions > 0: + kl_penalty_completion[b, :n_actions] = kl_penalty_full[b, action_indices[:n_actions]] + completion_mask[b, :n_actions] = 1.0 + modified_weight = (weight + kl_penalty_completion) * completion_mask # Create modified batch (shallow copy with updated weight) modified_batch = dict(batch) From 85d4f846af69b4d5229eebda050e5a26c7982735 Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 17:06:45 +0100 Subject: [PATCH 16/17] use per-token KL (fix Opus delusions) --- src/ludic/training/algorithm.py | 7 +++- src/ludic/training/credit_assignment.py | 25 ++++++++---- src/ludic/training/loss.py | 44 +++++++++++++++----- tests/test_credit_assignment.py | 53 ++++++++++++++++++++----- tests/test_loss.py | 25 ++++++++++++ 5 files changed, 124 insertions(+), 30 deletions(-) diff --git a/src/ludic/training/algorithm.py b/src/ludic/training/algorithm.py index d2f9d0f..f83a106 100644 --- a/src/ludic/training/algorithm.py +++ b/src/ludic/training/algorithm.py @@ -577,6 +577,7 @@ def make_gspo_opd( *, group_size: int, kl_coeff: float = 1.0, + kl_per_token: bool = True, group_normalize_adv: bool = True, positive_only: bool = False, clip_eps_low: float = 3e-4, @@ -602,12 +603,14 @@ def make_gspo_opd( Pipeline: 1. GroupNormalizedReturn computes task-based advantages - 2. KLCreditModifier adds negative KL to advantages + 2. KLCreditModifier adds negative KL to advantages (per-token if kl_per_token=True) 3. ClippedSurrogateLoss computes policy gradient with modified advantages Args: group_size: Number of rollouts per group for advantage normalization. kl_coeff: Coefficient for KL penalty. Higher = stronger teacher matching. 
+ kl_per_token: If True, apply KL penalty per action token by broadcasting + scalar advantages to token-level weights. group_normalize_adv: Normalize advantages within each group. positive_only: Clip negative advantages to zero. clip_eps_low: Lower PPO clipping epsilon. @@ -646,7 +649,7 @@ def make_gspo_opd( positive_only=positive_only, ) - credit_modifiers = [KLCreditModifier(coeff=kl_coeff)] + credit_modifiers = [KLCreditModifier(coeff=kl_coeff, broadcast_advantage=kl_per_token)] loss: Loss = ClippedSurrogateLoss( clip_eps_low=clip_eps_low, diff --git a/src/ludic/training/credit_assignment.py b/src/ludic/training/credit_assignment.py index 31459ee..fc5b270 100644 --- a/src/ludic/training/credit_assignment.py +++ b/src/ludic/training/credit_assignment.py @@ -79,6 +79,8 @@ class KLCreditModifier: Args: coeff: Coefficient for KL penalty. Higher = stronger teacher matching. name: Modifier name for logging. Metrics appear as "{name}/kl_mean", etc. + broadcast_advantage: If True, expand scalar advantages to per-token + values over action tokens so KL is applied per token. Requires batch to have: - "actor_logps": [B, T] old policy logprobs from rollout @@ -94,6 +96,7 @@ class KLCreditModifier: coeff: float = 1.0 name: str = "kl" + broadcast_advantage: bool = False def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: if "actor_logps" not in batch: @@ -113,22 +116,28 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: weight = batch["weight"] # [B], [B, C], or [B, T] B, T = action_mask.shape + action_mask_f = action_mask.float() # Reverse KL: log π_student - log π_teacher (full sequence) # We want to minimize this, so we add NEGATIVE KL to advantages reverse_kl = actor_logps - teacher_logps # [B, T] kl_penalty_full = -self.coeff * reverse_kl # [B, T] - # Handle different weight shapes, keeping output in same format as input. + # Handle different weight shapes. 
By default we keep the input format, + # but can broadcast scalar weights to per-token advantages if requested. # The loss function expects weight to match ratio shape, which may be # completion-only [B, C] rather than full sequence [B, T]. if weight.dim() == 1: - # Turn-level [B]: add completion KL summed over action tokens. - kl_penalty_scalar = (kl_penalty_full * action_mask.float()).sum(dim=-1) - modified_weight = weight + kl_penalty_scalar + # Turn-level [B]: either keep scalar or broadcast to per-token. + if self.broadcast_advantage: + base_adv = weight.unsqueeze(-1) * action_mask_f + modified_weight = (base_adv + kl_penalty_full) * action_mask_f + else: + kl_penalty_scalar = (kl_penalty_full * action_mask_f).sum(dim=-1) + modified_weight = weight + kl_penalty_scalar elif weight.shape[-1] == T: # Full sequence [B, T]: add KL directly, mask to action tokens. - modified_weight = (weight + kl_penalty_full) * action_mask.float() + modified_weight = (weight + kl_penalty_full) * action_mask_f else: # Completion-only [B, C]: extract KL for action tokens and align to completion positions. 
C = weight.shape[-1] @@ -153,9 +162,9 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: # Compute metrics (masked to action tokens) mask_sum = action_mask.sum() if mask_sum > 0: - masked_kl = reverse_kl * action_mask.float() + masked_kl = reverse_kl * action_mask_f kl_mean = masked_kl.sum() / mask_sum - kl_std = ((masked_kl - kl_mean * action_mask.float()) ** 2).sum() / mask_sum + kl_std = ((masked_kl - kl_mean * action_mask_f) ** 2).sum() / mask_sum kl_std = kl_std.sqrt() else: kl_mean = reverse_kl.new_zeros(()) @@ -164,7 +173,7 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: metrics = { "kl_mean": kl_mean.detach(), "kl_std": kl_std.detach(), - "kl_penalty_mean": (kl_penalty_full * action_mask.float()).sum().detach() / mask_sum.clamp(min=1), + "kl_penalty_mean": (kl_penalty_full * action_mask_f).sum().detach() / mask_sum.clamp(min=1), } return modified_batch, metrics diff --git a/src/ludic/training/loss.py b/src/ludic/training/loss.py index ccf9104..54aa89b 100644 --- a/src/ludic/training/loss.py +++ b/src/ludic/training/loss.py @@ -339,7 +339,7 @@ class ClippedSurrogateLoss: L_clip = - E[ min(r * A, clip(r, 1 - eps_low, 1 + eps_high) * A) ] Expects: - - batch["weight"]: A (advantages) [B] + - batch["weight"]: A (advantages) [B] or token-level [B, T] - batch[old_logp_key]: log π_old(a|s) [B] - input_ids / attention_mask / action_mask for π_new. 
@@ -365,7 +365,7 @@ def __post_init__(self) -> None: def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]]: input_ids = batch["input_ids"] action_mask = batch["action_mask"] - advantages = batch["weight"] # [B] + advantages = batch["weight"] if self.old_logp_key not in batch: raise KeyError(f"ClippedSurrogateLoss requires '{self.old_logp_key}' in batch.") @@ -387,13 +387,28 @@ def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]] if self.ratio_clip is not None: ratio = torch.clamp(ratio, max=self.ratio_clip) - unclipped = ratio * advantages - clipped = torch.clamp( + ratio_clipped = torch.clamp( ratio, 1.0 - self.clip_eps_low, 1.0 + self.clip_eps_high - ) * advantages - - obj = torch.min(unclipped, clipped) - loss = -obj.mean() + ) + if advantages.dim() == 2: + if advantages.shape != action_mask.shape: + raise ValueError( + "ClippedSurrogateLoss expects token-level weights to match action_mask shape " + f"{tuple(action_mask.shape)}, got {tuple(advantages.shape)}." 
+ ) + adv_mask = action_mask.to(advantages.dtype) + ratio_exp = ratio.unsqueeze(-1) + ratio_clipped_exp = ratio_clipped.unsqueeze(-1) + unclipped = ratio_exp * advantages + clipped = ratio_clipped_exp * advantages + obj = torch.min(unclipped, clipped) * adv_mask + adv_denom = adv_mask.sum().clamp(min=1.0) + loss = -obj.sum() / adv_denom + else: + unclipped = ratio * advantages + clipped = ratio_clipped * advantages + obj = torch.min(unclipped, clipped) + loss = -obj.mean() ppo_clip_frac = ( (ratio > 1.0 + self.clip_eps_high) | (ratio < 1.0 - self.clip_eps_low) @@ -403,6 +418,15 @@ def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]] else: ratio_clip_frac = torch.zeros((), device=ratio.device, dtype=ratio.dtype) + if advantages.dim() == 2: + masked_adv = advantages * adv_mask + adv_mean = masked_adv.sum() / adv_denom + adv_var = ((masked_adv - adv_mean) * adv_mask).pow(2).sum() / adv_denom + adv_std = adv_var.sqrt() + else: + adv_mean = advantages.mean() + adv_std = advantages.std(unbiased=False) + stats = { "loss": loss.detach(), "ratio_mean": ratio.mean().detach(), @@ -410,8 +434,8 @@ def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]] "clip_frac": ppo_clip_frac.detach(), "ratio_clip_frac": ratio_clip_frac.detach(), "kl_actor_policy": mismatch_kl.mean().detach(), - "adv_mean": advantages.mean().detach(), - "adv_std": advantages.std(unbiased=False).detach(), + "adv_mean": adv_mean.detach(), + "adv_std": adv_std.detach(), "logp_mean": logp_action.mean().detach(), } return loss, stats diff --git a/tests/test_credit_assignment.py b/tests/test_credit_assignment.py index 0f51469..78d1216 100644 --- a/tests/test_credit_assignment.py +++ b/tests/test_credit_assignment.py @@ -443,11 +443,11 @@ def test_kl_credit_modifier_zeroes_prompt_tokens(): def test_kl_credit_modifier_turn_level_weight(): """ - Test that turn-level (1D) weights are broadcast to completion-only format. 
+ Test that turn-level (1D) weights stay scalar per sample by default. Credit assigners produce one weight per turn/step, but KL is per-token. - The modifier broadcasts the turn weight and adds per-token KL, outputting - completion-only [B, max_completion_len] to match ratio shape in loss. + The modifier keeps the scalar shape and adds the summed per-token KL + over action tokens. """ # Turn-level weight: one value per sample (shape [B]) weight = torch.tensor([2.0, -1.0]) # two samples with advantages 2.0 and -1.0 @@ -476,16 +476,49 @@ def test_kl_credit_modifier_turn_level_weight(): modifier = KLCreditModifier(coeff=1.0) modified_batch, _ = modifier.modify(batch) - # Output is completion-only [B, max_completion_len=3] - # Sample 0: weight=2.0 + kl_penalty=[-0.5, -0.5, -0.5] -> [1.5, 1.5, 1.5] - # Sample 1: weight=-1.0 + kl_penalty=[0.5, 0.5, 0.0] -> [-0.5, -0.5, 0.0] - # (only 2 action tokens, so 3rd position is 0 padding) + # Output stays [B] with summed KL over action tokens. + # Sample 0: weight=2.0 + sum(kl_penalty=-0.5 * 3 tokens) = 0.5 + # Sample 1: weight=-1.0 + sum(kl_penalty=0.5 * 2 tokens) = 0.0 + expected_weight = torch.tensor([0.5, 0.0]) + + assert modified_batch["weight"].shape == (2,) + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + +def test_kl_credit_modifier_turn_level_weight_broadcast(): + """ + Test that turn-level (1D) weights can be broadcast to per-token advantages. 
+ """ + weight = torch.tensor([2.0, -1.0]) + actor_logps = torch.tensor([ + [-1.0, -1.0, -1.0, -1.0], + [-2.0, -2.0, -2.0, -2.0], + ]) + teacher_logps = torch.tensor([ + [-1.5, -1.5, -1.5, -1.5], + [-1.5, -1.5, -1.5, -1.5], + ]) + action_mask = torch.tensor([ + [0, 1, 1, 1], + [0, 0, 1, 1], + ]) + + batch = { + "weight": weight, + "actor_logps": actor_logps, + "teacher_logps": teacher_logps, + "action_mask": action_mask, + } + + modifier = KLCreditModifier(coeff=1.0, broadcast_advantage=True) + modified_batch, _ = modifier.modify(batch) + expected_weight = torch.tensor([ - [1.5, 1.5, 1.5], - [-0.5, -0.5, 0.0], # 3rd is padding (no action token there) + [0.0, 1.5, 1.5, 1.5], + [0.0, 0.0, -0.5, -0.5], ]) - assert modified_batch["weight"].shape == (2, 3) # completion-only + assert modified_batch["weight"].shape == (2, 4) assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) diff --git a/tests/test_loss.py b/tests/test_loss.py index 6832009..078d7ae 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -257,6 +257,31 @@ def test_gspo_loss_length_normalize_affects_ratio(): assert torch.allclose(loss_norm, expected_loss_norm, atol=1e-4) +def test_gspo_loss_token_advantages_masked_mean(): + logits = torch.tensor([[ + [0.0, 0.0], # predicts token at pos 1 + [2.0, 0.0], # predicts token at pos 2 + [0.0, 0.0], # unused + ]], dtype=torch.float32) + input_ids = torch.tensor([[0, 1, 0]], dtype=torch.long) + action_mask = torch.tensor([[0, 1, 1]], dtype=torch.float32) + + logp_action = compute_logp_action(logits, input_ids, action_mask) + batch = { + "input_ids": input_ids, + "action_mask": action_mask, + "weight": torch.tensor([[0.0, 2.0, -1.0]], dtype=torch.float32), + "old_logp_action": logp_action, + } + + loss_fn = ClippedSurrogateLoss(clip_eps_low=1.0, clip_eps_high=1.0, length_normalize=False) + loss, _ = loss_fn.compute(logits, batch) + + # ratio = 1.0, so loss is the masked mean of advantages over action tokens. 
+ # (2.0 + -1.0) / 2 = 0.5 => loss = -0.5 + assert torch.allclose(loss, torch.tensor(-0.5), atol=1e-4) + + def test_gspo_loss_upper_clip_positive_advantage(): logits = torch.tensor([[[0.0, 0.0], [0.0, 0.0]]], dtype=torch.float32) input_ids = torch.tensor([[0, 1]], dtype=torch.long) From aae5e6676fdcd67946483ff896616656a5c5387b Mon Sep 17 00:00:00 2001 From: hallerite Date: Mon, 29 Dec 2025 18:15:21 +0100 Subject: [PATCH 17/17] make per_token_kl the default --- examples/opd/README.md | 13 ++++--- src/ludic/training/algorithm.py | 7 ++-- src/ludic/training/credit_assignment.py | 17 +++------- tests/test_credit_assignment.py | 45 +++---------------------- 4 files changed, 18 insertions(+), 64 deletions(-) diff --git a/examples/opd/README.md b/examples/opd/README.md index dadb01b..83ef09b 100644 --- a/examples/opd/README.md +++ b/examples/opd/README.md @@ -6,9 +6,9 @@ This hybrid approach combines: - **GSPO (Group-Sorted Policy Optimization)**: Task rewards from GSM8K correctness with group-normalized advantages - **OPD (On-Policy Distillation)**: Dense per-token feedback via reverse KL divergence from teacher -The composite loss balances: -1. **Task-specific learning**: Sparse but grounded rewards from environment -2. **Distribution matching**: Dense per-token guidance from teacher +The hybrid adds KL penalty directly to advantages (Level 2: Advantage Modification): +1. **Task-specific learning**: Sparse but grounded rewards from environment → group-normalized advantages +2. **Distribution matching**: Dense per-token KL penalty added to advantages Reference: https://thinkingmachines.ai/blog/on-policy-distillation @@ -84,10 +84,9 @@ CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. 
uv run python examples/opd/train_opd_gsm8k.p ### Training logs Output includes: -- `train/loss`: Combined loss (GSPO + KL) -- `train/gspo/loss`: GSPO policy gradient loss -- `train/kl/loss`: Reverse KL loss -- `train/kl/reverse_kl_mean`: Mean per-token KL divergence +- `train/loss`: Policy gradient loss with KL-modified advantages +- `train/kl/kl_mean`: Mean per-token reverse KL (actor - teacher logprobs) +- `train/kl/kl_penalty_mean`: Mean KL penalty added to advantages - `train/correct_rate`: GSM8K accuracy on training samples - `train/avg_completion_length`: Average tokens per completion - `eval/accuracy`: GSM8K accuracy on test set diff --git a/src/ludic/training/algorithm.py b/src/ludic/training/algorithm.py index f83a106..d2f9d0f 100644 --- a/src/ludic/training/algorithm.py +++ b/src/ludic/training/algorithm.py @@ -577,7 +577,6 @@ def make_gspo_opd( *, group_size: int, kl_coeff: float = 1.0, - kl_per_token: bool = True, group_normalize_adv: bool = True, positive_only: bool = False, clip_eps_low: float = 3e-4, @@ -603,14 +602,12 @@ def make_gspo_opd( Pipeline: 1. GroupNormalizedReturn computes task-based advantages - 2. KLCreditModifier adds negative KL to advantages (per-token if kl_per_token=True) + 2. KLCreditModifier adds negative KL to advantages 3. ClippedSurrogateLoss computes policy gradient with modified advantages Args: group_size: Number of rollouts per group for advantage normalization. kl_coeff: Coefficient for KL penalty. Higher = stronger teacher matching. - kl_per_token: If True, apply KL penalty per action token by broadcasting - scalar advantages to token-level weights. group_normalize_adv: Normalize advantages within each group. positive_only: Clip negative advantages to zero. clip_eps_low: Lower PPO clipping epsilon. 
@@ -649,7 +646,7 @@ def make_gspo_opd( positive_only=positive_only, ) - credit_modifiers = [KLCreditModifier(coeff=kl_coeff, broadcast_advantage=kl_per_token)] + credit_modifiers = [KLCreditModifier(coeff=kl_coeff)] loss: Loss = ClippedSurrogateLoss( clip_eps_low=clip_eps_low, diff --git a/src/ludic/training/credit_assignment.py b/src/ludic/training/credit_assignment.py index fc5b270..5d157b0 100644 --- a/src/ludic/training/credit_assignment.py +++ b/src/ludic/training/credit_assignment.py @@ -79,8 +79,6 @@ class KLCreditModifier: Args: coeff: Coefficient for KL penalty. Higher = stronger teacher matching. name: Modifier name for logging. Metrics appear as "{name}/kl_mean", etc. - broadcast_advantage: If True, expand scalar advantages to per-token - values over action tokens so KL is applied per token. Requires batch to have: - "actor_logps": [B, T] old policy logprobs from rollout @@ -96,7 +94,6 @@ class KLCreditModifier: coeff: float = 1.0 name: str = "kl" - broadcast_advantage: bool = False def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: if "actor_logps" not in batch: @@ -123,18 +120,14 @@ def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]: reverse_kl = actor_logps - teacher_logps # [B, T] kl_penalty_full = -self.coeff * reverse_kl # [B, T] - # Handle different weight shapes. By default we keep the input format, - # but can broadcast scalar weights to per-token advantages if requested. + # Handle different weight shapes. Scalar weights are broadcast to + # per-token advantages so KL is applied per action token. # The loss function expects weight to match ratio shape, which may be # completion-only [B, C] rather than full sequence [B, T]. if weight.dim() == 1: - # Turn-level [B]: either keep scalar or broadcast to per-token. 
- if self.broadcast_advantage: - base_adv = weight.unsqueeze(-1) * action_mask_f - modified_weight = (base_adv + kl_penalty_full) * action_mask_f - else: - kl_penalty_scalar = (kl_penalty_full * action_mask_f).sum(dim=-1) - modified_weight = weight + kl_penalty_scalar + # Turn-level [B]: broadcast to per-token advantages. + base_adv = weight.unsqueeze(-1) * action_mask_f + modified_weight = (base_adv + kl_penalty_full) * action_mask_f elif weight.shape[-1] == T: # Full sequence [B, T]: add KL directly, mask to action tokens. modified_weight = (weight + kl_penalty_full) * action_mask_f diff --git a/tests/test_credit_assignment.py b/tests/test_credit_assignment.py index 78d1216..ffd4cb4 100644 --- a/tests/test_credit_assignment.py +++ b/tests/test_credit_assignment.py @@ -443,11 +443,10 @@ def test_kl_credit_modifier_zeroes_prompt_tokens(): def test_kl_credit_modifier_turn_level_weight(): """ - Test that turn-level (1D) weights stay scalar per sample by default. + Test that turn-level (1D) weights are broadcast to per-token advantages. Credit assigners produce one weight per turn/step, but KL is per-token. - The modifier keeps the scalar shape and adds the summed per-token KL - over action tokens. + The modifier broadcasts the scalar weight and adds per-token KL. """ # Turn-level weight: one value per sample (shape [B]) weight = torch.tensor([2.0, -1.0]) # two samples with advantages 2.0 and -1.0 @@ -476,43 +475,9 @@ def test_kl_credit_modifier_turn_level_weight(): modifier = KLCreditModifier(coeff=1.0) modified_batch, _ = modifier.modify(batch) - # Output stays [B] with summed KL over action tokens. 
- # Sample 0: weight=2.0 + sum(kl_penalty=-0.5 * 3 tokens) = 0.5 - # Sample 1: weight=-1.0 + sum(kl_penalty=0.5 * 2 tokens) = 0.0 - expected_weight = torch.tensor([0.5, 0.0]) - - assert modified_batch["weight"].shape == (2,) - assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) - - -def test_kl_credit_modifier_turn_level_weight_broadcast(): - """ - Test that turn-level (1D) weights can be broadcast to per-token advantages. - """ - weight = torch.tensor([2.0, -1.0]) - actor_logps = torch.tensor([ - [-1.0, -1.0, -1.0, -1.0], - [-2.0, -2.0, -2.0, -2.0], - ]) - teacher_logps = torch.tensor([ - [-1.5, -1.5, -1.5, -1.5], - [-1.5, -1.5, -1.5, -1.5], - ]) - action_mask = torch.tensor([ - [0, 1, 1, 1], - [0, 0, 1, 1], - ]) - - batch = { - "weight": weight, - "actor_logps": actor_logps, - "teacher_logps": teacher_logps, - "action_mask": action_mask, - } - - modifier = KLCreditModifier(coeff=1.0, broadcast_advantage=True) - modified_batch, _ = modifier.modify(batch) - + # Output is per-token [B, T] with masked prompt positions. + # Sample 0: weight=2.0 + kl_penalty=[-0.5, -0.5, -0.5] -> [1.5, 1.5, 1.5] + # Sample 1: weight=-1.0 + kl_penalty=[0.5, 0.5] -> [-0.5, -0.5] expected_weight = torch.tensor([ [0.0, 1.5, 1.5, 1.5], [0.0, 0.0, -0.5, -0.5],