diff --git a/docs/composition.md b/docs/composition.md new file mode 100644 index 0000000..4bd6ea0 --- /dev/null +++ b/docs/composition.md @@ -0,0 +1,163 @@ +# Signal Composition in Ludic + +This document describes how different training signals (rewards, advantages, losses) can be composed in Ludic's RL training pipeline. + +## The Training Pipeline + +``` +Environment ──► Rewards ──► CreditAssigner ──► Advantages ──► Loss + │ │ + ▼ ▼ + Scorers CreditModifiers + (Level 1) (Level 2) +``` + +There are **three composition levels**: + +| Level | Name | Where | Implementation | +|-------|------|-------|----------------| +| **1** | Reward | Before credit assignment | Agent scorers | +| **2** | Advantage | After credit assignment, before loss | CreditModifier | +| **3** | Loss | Separate loss terms | CompositeLoss | + +## Level 1: Reward Composition + +Add signals to rewards via Agent scorers, before credit assignment. + +```python +# Scorers attached to Agent add to per-step rewards +agent = Agent( + client=client, + scorers=[intrinsic_reward_scorer], # adds to step rewards +) +``` + +**Properties:** +- All signals go through the same credit assignment +- Signals interact (e.g., group normalization affects combined rewards) +- Tightest coupling between signals + +**Use when:** +- Intrinsic rewards should be treated identically to environment rewards +- You want signals to interact during advantage estimation + +## Level 2: Advantage Modification + +Modify advantages after credit assignment, before loss. 
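To make the mechanics concrete, a credit modifier can be pictured as a small object that rewrites a batch's advantages and reports stats. The sketch below uses plain Python lists and a hypothetical `KLCreditModifierSketch` class; Ludic's actual `CreditModifier`/`KLCreditModifier` operate on collated tensor batches and their exact signatures may differ:

```python
# Illustrative sketch of Level 2 (advantage modification) with plain lists.
# KLCreditModifierSketch is a hypothetical name, not Ludic's KLCreditModifier.
class KLCreditModifierSketch:
    def __init__(self, coeff: float = 1.0):
        self.coeff = coeff
        self.name = "kl"

    def modify(self, batch: dict) -> tuple[dict, dict]:
        # Per-token reverse-KL estimate recorded at rollout time (old policy):
        # actor logprob minus teacher logprob on the same sampled tokens.
        kl = [a - t for a, t in zip(batch["actor_logps"], batch["teacher_logps"])]
        penalty = [-self.coeff * k for k in kl]
        out = dict(batch)
        # Shift each token's advantage BEFORE the loss sees it, so the
        # combined signal goes through the importance ratio together.
        out["advantages"] = [adv + p for adv, p in zip(batch["advantages"], penalty)]
        stats = {
            "kl_mean": sum(kl) / len(kl),
            "kl_penalty_mean": sum(penalty) / len(penalty),
        }
        return out, stats
```

The arithmetic is the point: each token's advantage is shifted by `-coeff * (actor_logp - teacher_logp)` before the policy-gradient loss is computed.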
+ +```python +# KL penalty added to advantages +kl_penalty = -kl_coeff * (actor_logps - teacher_logps) +advantages = task_advantages + kl_penalty +# Then normal policy gradient with combined advantages +``` + +**Properties:** +- Each signal can have its own credit assignment strategy +- All signals go through the same importance ratio +- All signals go through the same loss function + +**Use when:** +- Different signals need different credit assignment (e.g., sparse task rewards vs dense KL) +- You want all signals to go through importance sampling together + +**Implementation in Ludic:** + +Use `CreditModifier` to add per-token signals to advantages: + +```python +algo = RLAlgorithm( + credit_assigner=GroupNormalizedReturn(group_size=8), + credit_modifiers=[KLCreditModifier(coeff=1.0)], + loss=ClippedSurrogateLoss(...), +) +``` + +Or use the preset: + +```python +algo = make_gspo_opd(group_size=8, kl_coeff=1.0) +``` + +## Level 3: Loss Composition + +Combine independent loss terms additively. 
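Evaluating such a composite reduces to summing weight-scaled term losses. A minimal sketch under assumed interfaces (the `CompositeLossSketch` name and the `(logits, batch) -> (loss, stats)` callable shape are illustrative, mirroring how `loss.compute(logits, batch)` is used elsewhere in this document):

```python
# Illustrative sketch of Level 3 (loss composition); hypothetical names,
# not Ludic's actual CompositeLoss/LossTerm implementation.
class CompositeLossSketch:
    def __init__(self, terms):
        # terms: list of (name, loss_fn, weight); each loss_fn maps
        # (logits, batch) -> (scalar_loss, stats_dict) and is evaluated
        # independently of the other terms.
        self.terms = terms

    def compute(self, logits, batch):
        total = 0.0
        stats = {}
        for name, loss_fn, weight in self.terms:
            value, term_stats = loss_fn(logits, batch)
            total += weight * value
            stats[f"{name}/loss"] = value  # namespace per-term metrics
            for k, v in term_stats.items():
                stats[f"{name}/{k}"] = v
        return total, stats
```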
+
+```python
+loss = CompositeLoss(terms=[
+    LossTerm(name="rl", loss=ClippedSurrogateLoss(...), weight=1.0),
+    LossTerm(name="auxiliary", loss=SomeAuxiliaryLoss(...), weight=0.1),
+])
+```
+
+**Properties:**
+- Each loss computed independently
+- Different losses can use different data (current vs old policy logprobs)
+- Loosest coupling
+- Most flexible, but signals don't interact
+
+**Use when:**
+- Truly independent objectives (e.g., RL + a language-modeling auxiliary)
+- Different losses need fundamentally different handling
+- You need maximum flexibility
+
+## Key Differences: Advantage vs Loss Composition
+
+| Aspect | Advantage Modification (Level 2) | Loss Composition (Level 3) |
+|--------|----------------------------------|---------------------------|
+| **KL source** | Old policy (rollout time) | Current policy (forward pass) |
+| **Importance sampling** | Goes through ratio | Doesn't go through ratio |
+| **Gradient** | `ratio * (task_adv + kl_penalty)` | `ratio * task_adv + kl_grad` |
+| **Interaction** | Signals combined | Signals independent |
+
+### Mathematical Difference
+
+**Advantage Modification:**
+```
+A_t = task_advantage - kl_coeff * KL_old_t
+L = E[ ratio_t * A_t ]
+∇L ∝ ratio_t * A_t * ∇log π_t
+```
+
+**Loss Composition:**
+```
+L = E[ ratio_t * task_advantage ] + kl_coeff * E[ KL_current_t ]
+∇L ∝ ratio_t * task_advantage * ∇log π_t + kl_coeff * ∇log π_t
+```
+
+In synchronous RL with a single gradient step per batch, `ratio ≈ 1` and `KL_old ≈ KL_current`, so the two behave similarly.
They diverge with:
+- Multiple epochs per batch (PPO-style)
+- Async RL with stale rollouts
+- Large policy updates
+
+## Recommended Patterns
+
+### Pattern 1: Pure RL (task rewards only)
+```python
+algo = make_gspo(group_size=8)
+```
+
+### Pattern 2: GSPO + OPD hybrid (recommended for distillation)
+```python
+algo = make_gspo_opd(group_size=8, kl_coeff=1.0)
+```
+
+### Pattern 3: Independent auxiliary loss
+```python
+algo = RLAlgorithm(
+    credit_assigner=GroupNormalizedReturn(group_size=8),
+    loss=CompositeLoss(terms=[
+        LossTerm(name="rl", loss=ClippedSurrogateLoss(...), weight=1.0),
+        LossTerm(name="lm", loss=LanguageModelingLoss(...), weight=0.1),
+    ]),
+)
+```
+
+## Summary
+
+| Scenario | Level | Implementation |
+|----------|-------|----------------|
+| Pure RL | - | `make_gspo()` |
+| RL + teacher distillation | 2 (Advantage) | `make_gspo_opd()` |
+| RL + unrelated auxiliary | 3 (Loss) | `CompositeLoss` |
+| Intrinsic rewards | 1 (Reward) | Agent scorers |
diff --git a/examples/opd/README.md b/examples/opd/README.md
new file mode 100644
index 0000000..83ef09b
--- /dev/null
+++ b/examples/opd/README.md
@@ -0,0 +1,121 @@
+# GSPO + OPD Hybrid Training on GSM8K
+
+Train a smaller student model using both task rewards and dense per-token supervision from a larger teacher model.
+
+This hybrid approach combines:
+- **GSPO (Group Sequence Policy Optimization)**: Task rewards from GSM8K correctness with group-normalized advantages
+- **OPD (On-Policy Distillation)**: Dense per-token feedback via reverse KL divergence from the teacher
+
+The hybrid adds the KL penalty directly to the advantages (Level 2: Advantage Modification):
+1. **Task-specific learning**: Sparse but grounded rewards from the environment → group-normalized advantages
+2. **Distribution matching**: Dense per-token KL penalty added to advantages
+
+Reference: https://thinkingmachines.ai/blog/on-policy-distillation
+
+## Prerequisites
+
+- At least 2 GPUs (e.g., 2x A100).
+ - GPU 0: Both vLLM servers (student 0.5B + teacher 7B fit together) + - GPU 1: Training (gradient updates) +- Required extra packages: `datasets`, `math-verify`. + +Install deps (once): +```bash +uv sync --extra examples +``` + +## 1) Start vLLM servers + +You need **two** vLLM servers: one for the student (sampling) and one for the teacher (scoring). For these small models, both can share GPU 0. + +**Important**: Student and teacher must use the **same tokenizer**. The Qwen2.5 family shares tokenizers across sizes, so this works. + +### Terminal 1: Student server (port 8000) +```bash +CUDA_VISIBLE_DEVICES=0 uv run python -m ludic.inference.vllm_server \ + --model Qwen/Qwen2.5-0.5B-Instruct \ + --port 8000 \ + --gpu-memory-utilization 0.4 +``` + +### Terminal 2: Teacher server (port 8001) +```bash +CUDA_VISIBLE_DEVICES=0 uv run python -m ludic.inference.vllm_server \ + --model Qwen/Qwen2.5-7B-Instruct \ + --port 8001 \ + --gpu-memory-utilization 0.5 +``` + +Wait for both servers to report ready before proceeding. + +## 2) Train with OPD + +In a third terminal, run the OPD training script on GPU 1: +```bash +CUDA_VISIBLE_DEVICES=1 PYTHONPATH=. 
uv run python examples/opd/train_opd_gsm8k.py \ + --student-model Qwen/Qwen2.5-0.5B-Instruct \ + --teacher-model Qwen/Qwen2.5-7B-Instruct \ + --student-port 8000 \ + --teacher-port 8001 \ + --rollouts-per-update 64 \ + --train-steps 100 \ + --micro-token-budget 16384 \ + --max-seq-len 1024 +``` + +### Key flags + +| Flag | Default | Description | +|------|---------|-------------| +| `--student-model` | `Qwen/Qwen2.5-0.5B-Instruct` | Student model (must match vLLM server) | +| `--teacher-model` | `Qwen/Qwen2.5-7B-Instruct` | Teacher model (must share tokenizer with student) | +| `--student-port` | 8000 | Student vLLM server port | +| `--teacher-port` | 8001 | Teacher vLLM server port | +| `--kl-coeff` | 1.0 | Coefficient for reverse KL loss term | +| `--rollouts-per-update` | 256 | Total rollouts per training step | +| `--group-size` | 8 | Group size for GSPO advantages | +| `--concurrency` | 32 | Parallel rollout generation | +| `--limit` | None | Limit training samples (None = use all) | +| `--logger` | `rich` | Loggers: rich, print, wandb, none (comma-separated) | +| `--eval-every` | 10 | Eval every N train steps | +| `--eval-limit` | 1000 | Number of test samples for eval | +| `--eval-temperature` | 0.0 | Sampling temperature for eval (greedy) | + +### Training logs + +Output includes: +- `train/loss`: Policy gradient loss with KL-modified advantages +- `train/kl/kl_mean`: Mean per-token reverse KL (actor - teacher logprobs) +- `train/kl/kl_penalty_mean`: Mean KL penalty added to advantages +- `train/correct_rate`: GSM8K accuracy on training samples +- `train/avg_completion_length`: Average tokens per completion +- `eval/accuracy`: GSM8K accuracy on test set +- `eval/parse_error_rate`: Parse error rate on test set + +Rollouts are written to `opd_rollouts.jsonl`. + +## How GSPO + OPD works + +This uses "Level 2: Advantage Modification" via CreditModifier (see `docs/composition.md`): + +1. 
**Student samples**: The student model generates completions for GSM8K problems
+2. **Environment rewards**: Each completion is graded for correctness (sparse reward)
+3. **Teacher scores**: The teacher model computes per-token logprobs on the student's samples
+4. **Credit assignment**: GroupNormalizedReturn computes task-based advantages
+5. **Credit modification**: KLCreditModifier adds the KL penalty to advantages:
+   ```
+   A_t = task_advantage + (-kl_coeff * (actor_logp_t - teacher_logp_t))
+   ```
+6. **Policy gradient**: ClippedSurrogateLoss with modified advantages
+
+Key benefits of this approach (vs CompositeLoss):
+- KL goes through importance sampling (multiplied by the ratio, like task rewards)
+- KL uses old policy logprobs from rollout time (not the current policy)
+- All signals interact through the same loss function
+
+## Tips
+
+- **Same tokenizer is required**: OPD passes token IDs directly from student to teacher. If the tokenizers differ, results will be meaningless.
+- **Context window**: Ensure prompt + completion fits in the teacher's context window. Truncation causes length mismatches.
+- **GPU memory**: With larger models, you may need separate GPUs for student and teacher. Adjust `--gpu-memory-utilization` accordingly.
+- **KL coefficient**: Start with `--kl-coeff 1.0`. Increase it if the student diverges too far from the teacher.
diff --git a/examples/opd/train_opd_gsm8k.py b/examples/opd/train_opd_gsm8k.py
new file mode 100644
index 0000000..685e0c6
--- /dev/null
+++ b/examples/opd/train_opd_gsm8k.py
@@ -0,0 +1,442 @@
+"""
+GSPO + OPD hybrid training on GSM8K using vLLM.
+
+This example combines:
+  - GSPO (Group Sequence Policy Optimization): Task rewards from GSM8K correctness
+  - OPD (On-Policy Distillation): Dense per-token supervision from a teacher
+
+The hybrid uses "Level 2: Advantage Modification" via CreditModifier:
+  1. GroupNormalizedReturn computes task-based advantages from correctness rewards
+  2.
KLCreditModifier adds negative KL to advantages: A_t += -coeff * (actor_logp - teacher_logp) + 3. ClippedSurrogateLoss computes policy gradient with modified advantages + +Key benefits of this approach (vs CompositeLoss): + - KL goes through importance sampling (like task rewards) + - KL uses old policy logprobs from rollout time + - All signals interact through the same loss function + +See docs/composition.md for the full composition level documentation. + +The key insight: teacher logprobs are an *intrinsic scorer* attached to the Agent. +The scorer runs during Agent.act() and scores flow through to training. + +Reference: https://thinkingmachines.ai/blog/on-policy-distillation + +Usage: + python train_opd_gsm8k.py \ + --student-model Qwen/Qwen2.5-0.5B-Instruct \ + --teacher-model Qwen/Qwen2.5-7B-Instruct \ + --limit 1000 + +Requirements: + - vLLM servers running for student (port 8000) and teacher (port 8001) +""" + +from __future__ import annotations + +import argparse +import os +import sys +import queue +from typing import List, Dict, Any + +import torch +from datasets import load_dataset # type: ignore +from transformers import AutoModelForCausalLM, AutoTokenizer + +from ludic.agent import Agent +from ludic.context import FullDialog +from ludic.inference import VLLMChatClient, InferenceSpec, SamplingParams, ReturnSpec, HFChatTemplate +from ludic.interaction import SingleAgentSyncProtocol +from ludic.parsers import boxed_parser +from ludic.distributed.adapters import create_vllm_publisher +from ludic.eval import EngineEvaluator +from ludic.training import ( + RolloutEngine, + RolloutBatchSource, + Trainer, + TrainerConfig, + CheckpointConfig, + make_dataset_queue_requests_fn, + RequestsExhausted, + RolloutRequest, + EnvSpec, + ProtocolSpec, + make_gspo_opd, +) +from ludic.training import Reducer, RichLiveLogger, PrintLogger, TeeLogger, WandbLogger, default_reducers +from ludic.training.scoring import make_vllm_teacher_scorer +from environments.gsm8k import 
GSM8KEnv + + +def load_gsm8k(split: str, limit: int | None) -> List[Dict[str, Any]]: + """Load GSM8K dataset samples.""" + ds = load_dataset("gsm8k", "main", split=split) + samples: List[Dict[str, Any]] = [] + for idx, row in enumerate(ds): + samples.append( + { + "question": row["question"], + "answer": row["answer"], + "id": row.get("id", idx), + } + ) + if limit is not None and len(samples) >= limit: + break + return samples + + +def main(): + parser = argparse.ArgumentParser(description="OPD training on GSM8K") + + # Model configuration + parser.add_argument("--student-model", default="Qwen/Qwen2.5-0.5B-Instruct", + help="Student model name/path") + parser.add_argument("--teacher-model", default="Qwen/Qwen2.5-7B-Instruct", + help="Teacher model name/path") + + # vLLM server configuration + parser.add_argument("--student-host", default="127.0.0.1") + parser.add_argument("--student-port", type=int, default=8000) + parser.add_argument("--teacher-host", default="127.0.0.1") + parser.add_argument("--teacher-port", type=int, default=8001) + + # Data configuration + parser.add_argument("--split", default="train") + parser.add_argument("--limit", type=int, default=None, + help="Limit training samples (None = use all)") + + # Training configuration + parser.add_argument("--rollouts-per-update", type=int, default=256, + help="Total rollouts per update (must be divisible by --group-size)") + parser.add_argument("--group-size", type=int, default=8, + help="Group size for grouped advantages") + parser.add_argument("--train-steps", type=int, default=20, + help="Number of training steps; 0 = run until samples exhausted") + parser.add_argument("--max-seq-len", type=int, default=1024, + help="Max tokens per sample") + parser.add_argument("--micro-token-budget", type=int, default=16384, + help="Max padded tokens per micro-batch") + parser.add_argument("--max-completion-tokens", type=int, default=512, + help="Max completion tokens per rollout") + 
parser.add_argument("--train-temperature", type=float, default=1.0, + help="Sampling temperature for training rollouts") + parser.add_argument("--concurrency", type=int, default=64, + help="Rollout concurrency") + + # OPD-specific configuration (hybrid GSPO + KL) + parser.add_argument("--kl-coeff", type=float, default=1.0, + help="Coefficient for reverse KL loss term") + + # System prompt + parser.add_argument("--system-prompt", type=str, + default="First, think step by step. Then put your final answer inside \\boxed{...}.", + help="System prompt for GSM8K env; set to '' to use the model default.") + + # Logging + parser.add_argument("--rollout-log", type=str, default="opd_rollouts.jsonl") + parser.add_argument("--logger", type=str, default="rich", + help="Comma-separated loggers: rich, print, wandb, none.") + + # Evaluation + parser.add_argument("--eval-every", type=int, default=10, + help="Eval every N train steps.") + parser.add_argument("--eval-before-start", action="store_true", default=True, + help="Run eval once before training begins.") + parser.add_argument("--eval-limit", type=int, default=1000, + help="Number of test samples for eval.") + parser.add_argument("--eval-concurrency", type=int, default=64) + parser.add_argument("--eval-temperature", type=float, default=0.0, + help="Sampling temperature for eval passes.") + + # Checkpointing + parser.add_argument("--final-save", action="store_true", + help="Save a final checkpoint after training completes.") + + args = parser.parse_args() + + # Validation + if args.rollouts_per_update <= 0: + raise ValueError("--rollouts-per-update must be > 0.") + if args.rollouts_per_update % args.group_size != 0: + raise ValueError("--rollouts-per-update must be divisible by --group-size.") + if args.max_completion_tokens > args.max_seq_len: + raise ValueError("--max-completion-tokens must be <= --max-seq-len.") + + # Setup rollout log path + rollout_log_path = os.path.abspath(args.rollout_log) + 
os.makedirs(os.path.dirname(rollout_log_path) or ".", exist_ok=True) + # Touch the file so tailing works even before the first rollout is written + open(rollout_log_path, "a", encoding="utf-8").close() + + # Load training data + print(f"Loading GSM8K {args.split} split...") + train_samples = load_gsm8k(args.split, args.limit) + if not train_samples: + raise SystemExit("No GSM8K samples loaded.") + print(f"Loaded {len(train_samples)} training samples") + + # Load eval data + eval_samples = load_gsm8k("test", args.eval_limit) if args.eval_limit else [] + if eval_samples: + print(f"Loaded {len(eval_samples)} eval samples") + + # Create sample queue + samples_q: queue.Queue = queue.Queue() + for idx, s in enumerate(train_samples): + samples_q.put((idx, s)) + + # Load tokenizer and model + print(f"Loading student model: {args.student_model}") + tokenizer = AutoTokenizer.from_pretrained(args.student_model) + model = AutoModelForCausalLM.from_pretrained(args.student_model, dtype=torch.bfloat16) + model.to("cuda" if torch.cuda.is_available() else "cpu") + + # Create vLLM client for student + client = VLLMChatClient( + host=args.student_host, + port=args.student_port, + enable_weight_updates=True, + ) + publisher = create_vllm_publisher(client) + chat_template = HFChatTemplate(tokenizer) + + # Teacher scorer - computes per-token logprobs during Agent.act() + teacher_scorer = make_vllm_teacher_scorer( + base_url=f"http://{args.teacher_host}:{args.teacher_port}", + model=args.teacher_model, + ) + + # Registries + env_registry = { + "gsm8k": lambda sample: GSM8KEnv(sample=sample, system_prompt=args.system_prompt) + } + + def protocol_factory(): + return SingleAgentSyncProtocol( + agent=Agent( + client=client, + model=args.student_model, + ctx=FullDialog(), + parser=boxed_parser, + chat_template=chat_template, + scorers=[teacher_scorer], # OPD: teacher provides per-token scores + ) + ) + + protocol_registry = {"single_agent": protocol_factory} + + # Algorithm: GSPO + OPD hybrid 
(Advantage Composition)
+    # - GSPO: Task rewards with group-normalized advantages
+    # - OPD: KL penalty added to advantages via KLCreditModifier
+    # See docs/composition.md for composition level documentation
+    algo = make_gspo_opd(
+        group_size=args.group_size,
+        kl_coeff=args.kl_coeff,
+    )
+
+    # Engine + batch source
+    engine = RolloutEngine(
+        env_registry=env_registry,
+        protocol_registry=protocol_registry,
+        jsonl_path=rollout_log_path,
+    )
+    train_inference = InferenceSpec(
+        sampling=SamplingParams(
+            temperature=args.train_temperature,
+            max_tokens=args.max_completion_tokens,
+        ),
+        # Ask vLLM for token IDs + chosen-token logprobs for importance sampling
+        return_=ReturnSpec.for_rl(top_logprobs_k=1),
+    )
+    base_requests = args.rollouts_per_update // args.group_size
+    requests_fn = make_dataset_queue_requests_fn(
+        samples_q,
+        batch_size=base_requests,
+        env_kind="gsm8k",
+        protocol_kind="single_agent",
+        inference=train_inference,
+        protocol_kwargs={},
+        request_meta_fn=lambda idx, sample: {
+            "sample_index": idx,
+            "question_id": sample.get("id", idx),
+        },
+        env_seed_fn=lambda idx, _sample: idx,
+        sampling_seed_fn=lambda idx, _sample: idx,
+        group_size=args.group_size,
+    )
+    batch_source = RolloutBatchSource(
+        orchestrator=engine,
+        credit_assigner=algo.credit_assigner,
+        requests_fn=requests_fn,
+        max_steps=1,
+        concurrency=args.concurrency,
+    )
+
+    # Trainer config
+    cfg = TrainerConfig(
+        model_device="cuda" if torch.cuda.is_available() else "cpu",
+        max_seq_len=args.max_seq_len,
+        micro_token_budget=args.micro_token_budget,
+        max_grad_norm=0.5,
+        pad_token_id=tokenizer.pad_token_id,
+        eval_at_start=bool(args.eval_before_start and eval_samples),
+        eval_every_n_steps=(args.eval_every if args.eval_every and args.eval_every > 0 and eval_samples else None),
+        eval_concurrency=args.eval_concurrency,
+        eval_max_steps=1,
+    )
+
+    # Checkpoint config
+    checkpoint_cfg = CheckpointConfig(
+        output_dir="checkpoints_opd",
+        every_n_steps=25,
+        max_to_keep=2,
+        
save_optimizer=True, + ) + + # Reducers + reducers = { + "correct_rate": Reducer( + kind="count_true", + source="correct", + normalize_by="rollouts", + ), + "parse_err_rate": Reducer( + kind="count_true", + source="parse_error", + normalize_by="samples", + ), + "total_completion_tokens": Reducer( + kind="sum", + source="completion_length", + ), + } + reducers = {**default_reducers(), **reducers} + + # Logger keys + logger_keys = [ + "train/loss", + "train/kl/kl_mean", # KLCreditModifier: mean reverse KL per token + "train/kl/kl_penalty_mean", # KLCreditModifier: mean penalty added to advantages + "train/avg_total_reward", + "train/correct_rate", + "train/parse_err_rate", + "train/avg_completion_length", + "train/total_completion_tokens", + "eval/accuracy", + "eval/parse_error_rate", + "eval/avg_completion_tokens", + "train/target_rollouts", + "train/num_samples", + ] + + train_logger = None + raw_logger = args.logger or "rich" + logger_tokens = [tok.strip().lower() for tok in raw_logger.replace("+", ",").split(",") if tok.strip()] + valid_loggers = {"rich", "print", "wandb", "none"} + unknown = [tok for tok in logger_tokens if tok not in valid_loggers] + if unknown: + raise SystemExit(f"Unknown logger(s): {unknown}. 
Valid: {sorted(valid_loggers)}") + if "none" in logger_tokens: + logger_tokens = ["none"] + + console_logger = None + if "print" in logger_tokens: + console_logger = PrintLogger(prefix="[opd]", keys=logger_keys, precision=4) + elif "rich" in logger_tokens: + if not sys.stdout.isatty(): + console_logger = PrintLogger(prefix="[opd]", keys=logger_keys, precision=4) + else: + console_logger = RichLiveLogger( + keys=logger_keys, + spark_key="train/avg_total_reward", + history=100, + precision=4, + ) + + wandb_logger = None + if "wandb" in logger_tokens: + wandb_logger = WandbLogger(config=dict(vars(args))) + + if logger_tokens != ["none"]: + if console_logger and wandb_logger: + train_logger = TeeLogger(console_logger, wandb_logger) + else: + train_logger = console_logger or wandb_logger + + # Eval reducers + eval_reducers = { + "accuracy": Reducer(kind="count_true", source="correct", normalize_by="samples", as_percent=True), + "parse_error_rate": Reducer(kind="count_true", source="parse_error", normalize_by="samples", as_percent=True), + "avg_completion_tokens": Reducer(kind="mean", source="completion_length"), + } + + # Create trainer + trainer = Trainer( + model=model, + algo=algo, + batch_source=batch_source, + publisher=publisher, + enable_gradient_checkpointing=True, + cfg=cfg, + checkpoint_config=checkpoint_cfg, + train_logger=train_logger, + reducers=reducers, + evaluator=( + None + if not eval_samples + else EngineEvaluator( + engine=RolloutEngine(env_registry=env_registry, protocol_registry=protocol_registry), + requests_fn=lambda: [ + RolloutRequest( + env=EnvSpec( + kind="gsm8k", + kwargs={"sample": sample}, + ), + protocol=ProtocolSpec(kind="single_agent"), + env_seed=idx, + sampling_seed=idx, + inference=InferenceSpec( + sampling=SamplingParams( + temperature=args.eval_temperature, + max_tokens=args.max_completion_tokens, + ), + return_=ReturnSpec.for_eval(return_token_ids=True), + ), + num_episodes=1, + meta={"eval_sample_index": idx, "question_id": 
sample.get("id", idx)}, + ) + for idx, sample in enumerate(eval_samples) + ], + reducers=eval_reducers, + max_steps=1, + timeout_s=cfg.eval_timeout_s, + concurrency=cfg.eval_concurrency, + ) + ), + ) + + # Train + print(f"\nStarting OPD training for {args.train_steps} steps...") + print(f" Student: {args.student_model}") + print(f" Teacher: {args.teacher_model}") + print(f" KL coefficient: {args.kl_coeff}") + print() + + try: + trainer.train_sync(args.train_steps) + except RequestsExhausted: + print("No more training samples; stopping.") + + if args.final_save: + try: + ckpt_path = trainer.save_checkpoint(metadata={"final": True}) + print(f"Final checkpoint saved to: {ckpt_path}") + except RuntimeError: + pass # No checkpointer configured + + print("\nTraining complete!") + + +if __name__ == "__main__": + main() diff --git a/src/ludic/agents/base_agent.py b/src/ludic/agents/base_agent.py index 55bea28..53c0d44 100644 --- a/src/ludic/agents/base_agent.py +++ b/src/ludic/agents/base_agent.py @@ -1,7 +1,7 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Mapping, TYPE_CHECKING +from dataclasses import dataclass, field +from typing import Any, Awaitable, Dict, List, Optional, Tuple, Mapping, TYPE_CHECKING, Union import torch @@ -13,6 +13,7 @@ if TYPE_CHECKING: from ludic.inference.chat_template import ChatTemplate + from ludic.training.scoring import IntrinsicScorer, TokenLevelScorer, ActionLevelScorer _DEFAULT_INCOMPLETE_FEEDBACK = ( "Your response was cut off because it exceeded the token limit. 
" @@ -65,6 +66,8 @@ class AgentActStep: loop_index: int tool_calls: Optional[List[Dict[str, Any]]] = None tool_results: Optional[List[Dict[str, Any]]] = None + intrinsic_scores: Dict[str, Union[List[float], float]] = field(default_factory=dict) + pending_score_tasks: Dict[str, Awaitable[Union[List[float], float]]] = field(default_factory=dict) @dataclass @@ -97,6 +100,7 @@ def __init__( incomplete_completion_penalty: float = -0.1, incomplete_completion_feedback: str = _DEFAULT_INCOMPLETE_FEEDBACK, chat_template: Optional["ChatTemplate"] = None, + scorers: Optional[List["IntrinsicScorer"]] = None, ) -> None: """ Initializes the Agent. @@ -113,6 +117,9 @@ def __init__( its completion is cut off. chat_template: ChatTemplate for token-in mode. If None, the agent will try to build an HFChatTemplate from client.tokenizer. + scorers: Optional list of intrinsic scorers (TokenLevelScorer or + ActionLevelScorer) to run after each action. Scores are attached + to AgentActStep.intrinsic_scores. """ self._client = client self._model = model @@ -132,6 +139,7 @@ def __init__( chat_template = HFChatTemplate(tokenizer) self._chat_template = chat_template + self._scorers: List["IntrinsicScorer"] = scorers or [] self.last_info: Dict[str, Any] = {} async def _infer_once( @@ -252,6 +260,31 @@ async def act( # 5. Parse (format the raw text action) parse_result = self._parser(raw_action) + # 6. 
Fire off intrinsic scorers (don't await - resolve later for parallelism) + pending_score_tasks: Dict[str, Awaitable[Union[List[float], float]]] = {} + if self._scorers: + from ludic.training.scoring import TokenLevelScorer, ActionLevelScorer + + prompt_text = self._chat_template.apply( + messages, add_generation_prompt=True + ).prompt_text + + for scorer in self._scorers: + if isinstance(scorer, TokenLevelScorer): + # Fire and forget - create task but don't await + task = asyncio.create_task( + scorer.score_tokens( + list(token_trace.prompt_token_ids), + list(token_trace.completion_token_ids), + ) + ) + pending_score_tasks[scorer.name] = task + elif isinstance(scorer, ActionLevelScorer): + task = asyncio.create_task( + scorer.score_action(prompt_text, raw_action) + ) + pending_score_tasks[scorer.name] = task + step = AgentActStep( prompt_messages=messages, action=raw_action, @@ -260,6 +293,7 @@ async def act( trace=token_trace, action_target="env", loop_index=0, + pending_score_tasks=pending_score_tasks, ) return AgentActResult(steps=[step]) diff --git a/src/ludic/interaction/multi_agent.py b/src/ludic/interaction/multi_agent.py index 788edea..8d21f26 100644 --- a/src/ludic/interaction/multi_agent.py +++ b/src/ludic/interaction/multi_agent.py @@ -172,6 +172,8 @@ async def run( turn_id=turn_ids[agent_id], tool_calls=act_step.tool_calls, tool_results=act_step.tool_results, + intrinsic_scores=act_step.intrinsic_scores, + pending_score_tasks=act_step.pending_score_tasks, ) collector.add(agent_id, agent_step) step_indices[agent_id] += 1 diff --git a/src/ludic/interaction/single_agent.py b/src/ludic/interaction/single_agent.py index af6543d..bc0985a 100644 --- a/src/ludic/interaction/single_agent.py +++ b/src/ludic/interaction/single_agent.py @@ -179,6 +179,8 @@ async def run( turn_id=turn_id, tool_calls=act_step.tool_calls, tool_results=act_step.tool_results, + intrinsic_scores=act_step.intrinsic_scores, + pending_score_tasks=act_step.pending_score_tasks, ) 
steps.append(agent_step) turn_agent_step_ids.append(agent_step.id) @@ -255,6 +257,8 @@ async def run( turn_id=turn_id, tool_calls=act_step.tool_calls, tool_results=act_step.tool_results, + intrinsic_scores=act_step.intrinsic_scores, + pending_score_tasks=act_step.pending_score_tasks, ) steps.append(agent_step) turn_agent_step_ids.append(agent_step.id) diff --git a/src/ludic/training/__init__.py b/src/ludic/training/__init__.py index ffd0a74..d19ef0f 100644 --- a/src/ludic/training/__init__.py +++ b/src/ludic/training/__init__.py @@ -26,6 +26,8 @@ make_gspo, make_cispo, make_sft, + make_opd, + make_gspo_opd, ) from .credit_assignment import ( GroupNormalizedReturn, @@ -33,6 +35,8 @@ PerStepReward, EpisodicReturn, ConstantCredit, + CreditModifier, + KLCreditModifier, ) from .loss import ( Loss, @@ -42,6 +46,7 @@ TokenClippedSurrogateLoss, CISPOLoss, KLLoss, + ReverseKLLoss, EntropyBonus, LossTerm, CompositeLoss, @@ -84,12 +89,17 @@ "make_gspo", "make_cispo", "make_sft", + "make_opd", + "make_gspo_opd", # Credit assignment "GroupNormalizedReturn", "MonteCarloReturn", "PerStepReward", "EpisodicReturn", "ConstantCredit", + # Credit modifiers (Level 2: Advantage Modification) + "CreditModifier", + "KLCreditModifier", # Losses "Loss", "ReinforceLoss", @@ -98,6 +108,7 @@ "TokenClippedSurrogateLoss", "CISPOLoss", "KLLoss", + "ReverseKLLoss", "EntropyBonus", "LossTerm", "CompositeLoss", diff --git a/src/ludic/training/algorithm.py b/src/ludic/training/algorithm.py index 720d447..d2f9d0f 100644 --- a/src/ludic/training/algorithm.py +++ b/src/ludic/training/algorithm.py @@ -1,7 +1,7 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any, Dict, Mapping, Optional, Protocol +from dataclasses import dataclass, field +from typing import Any, Dict, List, Mapping, Optional, Protocol from jaxtyping import Float from torch import nn, Tensor @@ -15,8 +15,15 @@ TokenClippedSurrogateLoss, CISPOLoss, MaskedCausalLMCrossEntropyLoss, + ReverseKLLoss, 
+) +from ludic.training.credit_assignment import ( + MonteCarloReturn, + GroupNormalizedReturn, + ConstantCredit, + CreditModifier, + KLCreditModifier, ) -from ludic.training.credit_assignment import MonteCarloReturn, GroupNormalizedReturn, ConstantCredit Batch = Mapping[str, Tensor] @@ -29,18 +36,37 @@ def __call__(self, saw_batch: SAWBatch) -> SAWBatch: ... @dataclass class RLAlgorithm: """ - Full RL algorithm = credit assignment + loss. - - - credit_assigner: maps Rollouts -> per-step scalar credits - (e.g. discounted returns / advantages) - - loss: consumes a collated batch (built from SAWBatch) and produces - a scalar loss and stats. - - name: identifier for logging / checkpoints + Full RL algorithm = credit assignment + credit modifiers + loss. + + Pipeline: + Rollouts → CreditAssigner → SAWBatch → Collator → Batch + ↓ + CreditModifiers + ↓ + Modified Batch + ↓ + Loss + + Components: + - credit_assigner: Rollouts → per-step scalar credits (advantages) + - credit_modifiers: Batch → Modified Batch (add per-token signals to advantages) + - loss: Batch + Logits → scalar loss + + CreditModifiers (Level 2: Advantage Modification): + Modify advantages AFTER collation, BEFORE loss. Signals added here: + - Go through importance sampling (multiplied by ratio) + - Use rollout-time values (old policy) + - Interact with task rewards through the same loss + + Example: KLCreditModifier adds teacher KL penalty to advantages. + + See docs/composition.md for the full composition level documentation. """ name: str credit_assigner: CreditAssigner loss: Loss + credit_modifiers: List[CreditModifier] = field(default_factory=list) preprocess: Optional[PreprocessFn] = None def compute_loss( @@ -49,8 +75,22 @@ def compute_loss( batch: Batch, ) -> tuple[Tensor, Dict[str, Any]]: """ - Runs the forward pass once and delegates to the Loss object. + Apply credit modifiers, run forward pass, compute loss. 
+
+        Returns:
+            Tuple of (loss, stats_dict) where stats_dict includes:
+            - Modifier metrics namespaced as "{modifier.name}/{metric}"
+            - Loss metrics from self.loss.compute()
         """
+        all_stats: Dict[str, Any] = {}
+
+        # --- Apply credit modifiers (Level 2: Advantage Modification) ---
+        for modifier in self.credit_modifiers:
+            batch, modifier_stats = modifier.modify(batch)
+            # Namespace modifier metrics
+            for key, value in modifier_stats.items():
+                all_stats[f"{modifier.name}/{key}"] = value
+        # --- Run the forward pass ---
         input_ids = batch["input_ids"]
         attention_mask = batch["attention_mask"]
@@ -60,8 +100,11 @@
         )
         logits: Logits = outputs.logits
 
-        # Pass the resulting logits to the loss function
-        return self.loss.compute(logits, batch)
+        # --- Compute loss ---
+        loss, loss_stats = self.loss.compute(logits, batch)
+        all_stats.update(loss_stats)
+
+        return loss, all_stats
 
 
 # ---------------------------------------------------------------------------
@@ -450,3 +493,178 @@ def make_sft(
         credit_assigner=credit_assigner,
         loss=loss,
     )
+
+
+# ---------------------------------------------------------------------------
+# On-Policy Distillation (OPD)
+# ---------------------------------------------------------------------------
+
+
+def make_opd(
+    *,
+    kl_coeff: float = 1.0,
+    length_normalize: bool = True,
+    name: str = "opd",
+) -> RLAlgorithm:
+    """
+    On-Policy Distillation (OPD).
+
+    Dense per-token supervision from a teacher model. The student samples
+    trajectories, the teacher provides logprobs, and training minimizes
+    reverse KL.
+
+    Core idea: sample from the student, grade each token with teacher logprobs.
+    This combines on-policy learning (samples from the student) with dense
+    supervision (per-token teacher signal), achieving better compute
+    efficiency than sparse RL rewards.
+
+    Credit assignment: ConstantCredit(1.0) - the per-step "advantage" is
+    implicit in the loss (negative reverse KL per token).
+
+    Loss: ReverseKLLoss - minimizes KL(student || teacher) per token.
+
+    Prerequisites:
+    - Agent must have a TokenLevelScorer with name="teacher_logps"
+    - Use make_vllm_teacher_scorer() to create the scorer
+    - Scorer runs during Agent.act() and scores flow through to training
+
+    Args:
+        kl_coeff: Coefficient for KL loss. Higher values push the student
+            harder towards the teacher's distribution. Default 1.0.
+        length_normalize: If True (default), divide per-sample loss by the
+            number of action tokens. This keeps gradients stable across
+            sequence lengths.
+        name: Algorithm name for logging/metrics.
+
+    Example:
+        ```python
+        from ludic.training import Trainer, make_opd
+        from ludic.training.scoring import make_vllm_teacher_scorer
+        from ludic.agent import Agent
+
+        # Create teacher scorer
+        teacher_scorer = make_vllm_teacher_scorer(
+            base_url="http://localhost:8001",
+            model="Qwen/Qwen3-32B",
+        )
+
+        # Attach to agent - scores flow through automatically
+        agent = Agent(
+            client=client,
+            ...,
+            scorers=[teacher_scorer],
+        )
+
+        # Train with OPD
+        trainer = Trainer(model=model, algo=make_opd(), ...)
+        ```
+
+    Reference: https://thinkingmachines.ai/blog/on-policy-distillation
+
+    Note:
+        This uses Loss Composition (Level 3). For OPD where the KL goes through
+        importance sampling (Level 2: Advantage Modification), use
+        make_gspo_opd() instead.
+    """
+    credit_assigner: CreditAssigner = ConstantCredit(value=1.0)
+    loss: Loss = ReverseKLLoss(coeff=kl_coeff, length_normalize=length_normalize)
+
+    return RLAlgorithm(
+        name=name,
+        credit_assigner=credit_assigner,
+        loss=loss,
+    )
+
+
+def make_gspo_opd(
+    *,
+    group_size: int,
+    kl_coeff: float = 1.0,
+    group_normalize_adv: bool = True,
+    positive_only: bool = False,
+    clip_eps_low: float = 3e-4,
+    clip_eps_high: float = 4e-4,
+    length_normalize: bool = True,
+    ratio_clip: Optional[float] = None,
+    drop_zero_weight: bool = False,
+    drop_zero_weight_eps: float = 1e-4,
+    name: str = "gspo_opd",
+) -> RLAlgorithm:
+    """
+    GSPO + OPD hybrid using advantage composition.
+
+    Combines:
+    - GSPO: task rewards with group-normalized advantages
+    - OPD: teacher KL penalty added to advantages (Level 2: Advantage Modification)
+
+    KL is computed at rollout time and added to advantages BEFORE the loss.
+    This means:
+    - KL goes through importance sampling (multiplied by the ratio)
+    - KL uses old policy logprobs (from rollout), not current policy
+    - All signals interact through the same loss function
+
+    Pipeline:
+    1. GroupNormalizedReturn computes task-based advantages
+    2. KLCreditModifier adds negative KL to advantages
+    3. ClippedSurrogateLoss computes the policy gradient with modified advantages
+
+    Args:
+        group_size: Number of rollouts per group for advantage normalization.
+        kl_coeff: Coefficient for KL penalty. Higher = stronger teacher matching.
+        group_normalize_adv: Normalize advantages within each group.
+        positive_only: Clip negative advantages to zero.
+        clip_eps_low: Lower PPO clipping epsilon.
+        clip_eps_high: Upper PPO clipping epsilon.
+        length_normalize: Normalize loss by the number of action tokens.
+        ratio_clip: Optional upper bound for ratio truncation.
+        drop_zero_weight: Drop zero-advantage samples before collation.
+        drop_zero_weight_eps: Epsilon for zero-weight detection.
+        name: Algorithm name for logging.
+
+    Prerequisites:
+    - Rollouts must return actor logprobs (ReturnSpec.for_rl())
+    - Agent must have a teacher scorer (make_vllm_teacher_scorer())
+    - Rollouts must have group_id for GSPO advantage normalization
+
+    Example:
+        ```python
+        from ludic.training import make_gspo_opd
+        from ludic.training.scoring import make_vllm_teacher_scorer
+
+        teacher_scorer = make_vllm_teacher_scorer(
+            base_url="http://localhost:8001",
+            model="Qwen/Qwen3-32B",
+        )
+
+        agent = Agent(client=client, ..., scorers=[teacher_scorer])
+        algo = make_gspo_opd(group_size=8, kl_coeff=1.0)
+        trainer = Trainer(model=model, algo=algo, ...)
+        ```
+
+    Reference: https://thinkingmachines.ai/blog/on-policy-distillation
+    """
+    credit_assigner: CreditAssigner = GroupNormalizedReturn(
+        group_size=group_size,
+        normalize_adv=group_normalize_adv,
+        positive_only=positive_only,
+    )
+
+    credit_modifiers = [KLCreditModifier(coeff=kl_coeff)]
+
+    loss: Loss = ClippedSurrogateLoss(
+        clip_eps_low=clip_eps_low,
+        clip_eps_high=clip_eps_high,
+        length_normalize=length_normalize,
+        ratio_clip=ratio_clip,
+    )
+
+    preprocess_fns = []
+    if drop_zero_weight:
+        preprocess_fns.append(lambda batch: drop_zero_weight_samples(batch, eps=drop_zero_weight_eps))
+    preprocess_fns.append(validate_actor_logps)
+    preprocess = compose_preprocess(*preprocess_fns)
+
+    return RLAlgorithm(
+        name=name,
+        credit_assigner=credit_assigner,
+        credit_modifiers=credit_modifiers,
+        loss=loss,
+        preprocess=preprocess,
+    )
diff --git a/src/ludic/training/batching/micro_batching.py b/src/ludic/training/batching/micro_batching.py
index 98b27fb..f76aba2 100644
--- a/src/ludic/training/batching/micro_batching.py
+++ b/src/ludic/training/batching/micro_batching.py
@@ -6,7 +6,7 @@
 import torch
 from torch import Tensor
 
-from ludic.training.types import SAWItem, ActorTokenLogps, SampleAttachments
+from ludic.training.types import SAWItem, ActorTokenLogps, TeacherLogprobs, SampleAttachments
 
 @dataclass(frozen=True)
@@ -95,6 +95,28 @@ def collate_saw_items(
         batch["actor_logps"] = actor_logps_batch
         batch["old_logp_action"] = old_logp_action
 
+    # teacher_logps is optional; required for on-policy distillation (OPD).
+    teachers = [it.teacher_logps for it in items]
+    if any(teacher is not None for teacher in teachers):
+        if any(teacher is None for teacher in teachers):
+            raise ValueError(
+                "Mixed presence of teacher_logps; either provide it for all samples or none."
+            )
+        teacher_logps_batch = torch.zeros((batch_size, max_len), dtype=torch.float32, device=device)
+        for b, (it, teacher) in enumerate(zip(items, teachers)):
+            assert teacher is not None
+            token_logps = torch.as_tensor(teacher.token_logps, dtype=torch.float32, device=device)
+            positions = torch.nonzero(action_mask[b] > 0.0, as_tuple=False).flatten()
+            if token_logps.numel() != positions.numel():
+                raise ValueError(
+                    f"Length mismatch between teacher_logps ({token_logps.numel()}) "
+                    f"and the number of action tokens ({positions.numel()})."
+                )
+            teacher_logps_batch[b, positions] = token_logps
+
+        batch["teacher_logps"] = teacher_logps_batch
+
     return batch
@@ -112,9 +134,20 @@ def _truncate_item(item: SAWItem, max_seq_len: int) -> SAWItem:
     prompt_tokens = len(input_ids) - action_tokens
 
     attachments = item.attachments
+    new_actor_logps = None
+    new_teacher_logps = None
     if attachments.actor_logps is not None:
-        token_logps = attachments.actor_logps.token_logps[:action_tokens]
-        attachments = SampleAttachments(actor_logps=ActorTokenLogps(token_logps=token_logps))
+        new_actor_logps = ActorTokenLogps(
+            token_logps=attachments.actor_logps.token_logps[:action_tokens]
+        )
+    if attachments.teacher_logps is not None:
+        new_teacher_logps = TeacherLogprobs(
+            token_logps=attachments.teacher_logps.token_logps[:action_tokens]
+        )
+    attachments = SampleAttachments(
+        actor_logps=new_actor_logps,
+        teacher_logps=new_teacher_logps,
+    )
 
     meta = dict(item.meta)
     meta["seq_len_truncated"] = True
diff --git a/src/ludic/training/batching/rollout_engine.py b/src/ludic/training/batching/rollout_engine.py
index d9deee0..b8d3d7e 100644
--- a/src/ludic/training/batching/rollout_engine.py
+++ b/src/ludic/training/batching/rollout_engine.py
@@ -18,6 +18,7 @@
     SAWItem,
     SAWBatch,
     ActorTokenLogps,
+    TeacherLogprobs,
     SampleAttachments,
     RolloutRequest,
     ProtocolSpec,
@@ -281,6 +282,10 @@ def _build_turn_saw_item(
     logprobs_mode: Optional[bool] = None
     logprobs: List[float] = []
 
+    # Collect intrinsic scores (token-level scores are concatenated across steps)
+    token_scores: Dict[str, List[float]] = {}
+    action_scores: Dict[str, float] = {}
+
     def _extend(ids: Sequence[int], *, is_action: bool) -> None:
         input_ids.extend(ids)
         attention_mask.extend([1] * len(ids))
@@ -332,6 +337,17 @@ def _extend(ids: Sequence[int], *, is_action: bool) -> None:
             logprobs_mode = True
             logprobs.extend(completion_logprobs)
 
+        # Collect intrinsic scores from agent step
+        for name, scores in step.intrinsic_scores.items():
+            if isinstance(scores, list):
+                # Token-level scores: concatenate across steps
+                if name not in token_scores:
+                    token_scores[name] = []
+                token_scores[name].extend(scores)
+            else:
+                # Action-level scores: use last step's score (or could sum/average)
+                action_scores[name] = float(scores)
+
     if require_chosen_logprobs and not logprobs_mode:
         raise ValueError(
             f"Missing completion_logprobs for rollout {rollout.id}, turn {turn.turn_id!r}, "
@@ -381,11 +397,25 @@ def _extend(ids: Sequence[int], *, is_action: bool) -> None:
         }
     )
 
-    attachments = SampleAttachments()
-    if logprobs_mode:
-        attachments = SampleAttachments(
-            actor_logps=ActorTokenLogps(token_logps=logprobs)
-        )
+    # Build attachments
+    actor_logps = ActorTokenLogps(token_logps=logprobs) if logprobs_mode else None
+    teacher_logps = None
+    if "teacher_logps" in token_scores:
+        teacher_logps = TeacherLogprobs(token_logps=token_scores["teacher_logps"])
+
+    attachments = SampleAttachments(
+        actor_logps=actor_logps,
+        teacher_logps=teacher_logps,
+    )
+
+    # Store other intrinsic scores in meta for now
+    # (could extend SampleAttachments later for more score types)
+    if token_scores:
+        meta["intrinsic_token_scores"] = {
+            k: v for k, v in token_scores.items() if k != "teacher_logps"
+        }
+    if action_scores:
+        meta["intrinsic_action_scores"] = action_scores
 
     return (
         SAWItem(
@@ -520,6 +550,52 @@ async def _run_one_request(
 
         return rollouts
 
+    async def _resolve_pending_scores(self, rollouts: List[Rollout]) -> None:
+        """
+        Resolve all pending score tasks from agent steps.
+
+        Intrinsic scorers (e.g., teacher logprobs) fire-and-forget during Agent.act().
+        This method awaits all pending tasks and stores results in intrinsic_scores.
+
+        This design allows scoring to run in parallel with:
+        - Other agent steps within the same rollout
+        - Other rollouts running concurrently
+
+        After this method returns, all AgentStep.intrinsic_scores are populated.
+        """
+        # Collect all pending tasks with their target locations
+        pending: List[Tuple[AgentStep, str, asyncio.Task]] = []
+
+        for rollout in rollouts:
+            for step in rollout.steps:
+                if isinstance(step, AgentStep) and step.pending_score_tasks:
+                    for name, task in step.pending_score_tasks.items():
+                        pending.append((step, name, task))
+
+        if not pending:
+            return
+
+        # Await all tasks concurrently
+        tasks = [t for (_, _, t) in pending]
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+
+        # Store results back in steps
+        for (step, name, _), result in zip(pending, results):
+            if isinstance(result, Exception):
+                # Log but don't fail - scorer errors shouldn't break training
+                import logging
+                logging.getLogger(__name__).warning(
+                    f"Scorer '{name}' failed for step {step.index}: {result}"
+                )
+            else:
+                step.intrinsic_scores[name] = result
+
+        # Clear pending tasks (they're resolved now)
+        for rollout in rollouts:
+            for step in rollout.steps:
+                if isinstance(step, AgentStep):
+                    step.pending_score_tasks = {}
+
     def _append_jsonl(self, rollout: Rollout) -> None:
         assert self.jsonl_path is not None
 
         def _serialize_step(step: Step) -> Dict[str, Any]:
@@ -645,6 +721,12 @@ async def generate_batch(
         - If sample_filter is provided, it's applied after SAWItems are created.
         - Filter returns True to KEEP a sample, False to DROP it.
         - Use ludic.training.filters for common predicates.
+
+        Intrinsic Scoring:
+        - Agents can have intrinsic scorers (TokenLevelScorer, ActionLevelScorer).
+        - Scores are computed during Agent.act() and flow through AgentStep.
+        - Token-level scores (e.g., teacher_logps) are attached to SAWItem.attachments.
+        - Use with the make_opd() algorithm for on-policy distillation training.
         """
         rollouts = await self.generate_rollouts(
             requests=requests,
@@ -652,6 +734,11 @@
             timeout_s=timeout_s,
             concurrency=concurrency,
         )
+
+        # Resolve pending score tasks from all agent steps.
+        # This allows scoring to run in parallel with rollout generation.
+        await self._resolve_pending_scores(rollouts)
+
         weights = credit_assigner.compute(rollouts)
 
         items_with_lengths: List[Tuple[SAWItem, int, int]] = []
diff --git a/src/ludic/training/batching/synced_batching.py b/src/ludic/training/batching/synced_batching.py
index 99c0aef..6fd7405 100644
--- a/src/ludic/training/batching/synced_batching.py
+++ b/src/ludic/training/batching/synced_batching.py
@@ -18,6 +18,10 @@ class RolloutBatchSource(BatchSource):
     Note: RolloutEngine now concatenates each agent turn into a single
     training sample (one SAWItem per turn), rather than emitting per-step
     samples.
+
+    For on-policy distillation (OPD), configure the Agent with a TokenLevelScorer
+    (e.g., make_vllm_teacher_scorer). The scorer computes teacher logprobs during
+    Agent.act() and they flow through to SAWItem.attachments.teacher_logps.
     """
 
     def __init__(
diff --git a/src/ludic/training/credit_assignment.py b/src/ludic/training/credit_assignment.py
index c17c351..5d157b0 100644
--- a/src/ludic/training/credit_assignment.py
+++ b/src/ludic/training/credit_assignment.py
@@ -2,13 +2,176 @@
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import Any, Dict, List, Mapping, Protocol, Tuple
 
 import torch
+from torch import Tensor
+
 from ludic.types import Rollout
 from ludic.training.types import RolloutStepKey
 
+Batch = Mapping[str, Tensor]
+
+
+# ---------------------------------------------------------------------------
+# Credit Modifiers (Level 2: Advantage Modification)
+# ---------------------------------------------------------------------------
+#
+# CreditModifiers operate on collated batches (after credit assignment) to
+# modify the per-token advantages/weights before loss computation.
+#
+# Key use case: OPD, where a KL penalty is added to advantages, causing it to
+# go through importance sampling like task rewards.
+#
+# See docs/composition.md for the full composition level documentation.
+# ---------------------------------------------------------------------------
+
+
+class CreditModifier(Protocol):
+    """
+    Modifies batch advantages after collation, before loss computation.
+
+    This is "Level 2: Advantage Modification" - adding per-token signals
+    (like a KL penalty) to trajectory-level advantages from credit assignment.
+
+    Unlike CompositeLoss (Level 3: Loss Composition), signals added here:
+    - Go through importance sampling (multiplied by the ratio)
+    - Use rollout-time (old policy) values, not current policy
+    - Interact with task rewards through the same loss function
+
+    Example:
+        >>> modifier = KLCreditModifier(coeff=1.0)
+        >>> batch, metrics = modifier.modify(batch)
+        >>> # batch["weight"] now includes KL penalty
+    """
+
+    name: str
+
+    def modify(self, batch: Batch) -> Tuple[Batch, Dict[str, Any]]:
+        """
+        Modify batch advantages and return metrics.
+
+        Args:
+            batch: Collated batch with at least:
+                - "weight": [B, T] per-token advantages from credit assignment
+                - Other fields depend on the modifier (e.g., "actor_logps", "teacher_logps")
+
+        Returns:
+            Tuple of (modified_batch, metrics_dict).
+            The batch should be modified in-place for efficiency.
+            Metrics are namespaced under self.name by the caller.
+        """
+        ...
+
+
+@dataclass
+class KLCreditModifier:
+    """
+    Add negative reverse KL to advantages for on-policy distillation.
+
+    Implements the KL penalty for OPD:
+        kl_advantage_t = -coeff * (actor_logps_t - teacher_logps_t)
+
+    The KL is computed at rollout time (old policy), so it goes through
+    importance sampling when used with ratio-based losses like PPO/GSPO.
+
+    Args:
+        coeff: Coefficient for KL penalty. Higher = stronger teacher matching.
+        name: Modifier name for logging. Metrics appear as "{name}/kl_mean", etc.
+
+    Requires the batch to have:
+    - "actor_logps": [B, T] old policy logprobs from rollout
+    - "teacher_logps": [B, T] teacher logprobs
+
+    Example:
+        >>> algo = RLAlgorithm(
+        ...     credit_assigner=GroupNormalizedReturn(group_size=8),
+        ...     credit_modifiers=[KLCreditModifier(coeff=1.0)],
+        ...     loss=ClippedSurrogateLoss(...),
+        ... )
+    """
+
+    coeff: float = 1.0
+    name: str = "kl"
+
+    def modify(self, batch: Batch) -> Tuple[Dict[str, Tensor], Dict[str, Any]]:
+        if "actor_logps" not in batch:
+            raise KeyError(
+                "KLCreditModifier requires batch['actor_logps']. "
+                "Ensure rollouts return actor logprobs (ReturnSpec.for_rl())."
+            )
+        if "teacher_logps" not in batch:
+            raise KeyError(
+                "KLCreditModifier requires batch['teacher_logps']. "
+                "Ensure agent has a teacher scorer (e.g., make_vllm_teacher_scorer())."
+            )
+
+        actor_logps = batch["actor_logps"]      # [B, T] full sequence
+        teacher_logps = batch["teacher_logps"]  # [B, T] full sequence
+        action_mask = batch["action_mask"]      # [B, T] full sequence
+        weight = batch["weight"]                # [B], [B, C], or [B, T]
+
+        B, T = action_mask.shape
+        action_mask_f = action_mask.float()
+
+        # Reverse KL: log π_student - log π_teacher (full sequence).
+        # We want to minimize this, so we add NEGATIVE KL to advantages.
+        reverse_kl = actor_logps - teacher_logps    # [B, T]
+        kl_penalty_full = -self.coeff * reverse_kl  # [B, T]
+
+        # Handle different weight shapes. Scalar weights are broadcast to
+        # per-token advantages so KL is applied per action token.
+        # The loss function expects weight to match the ratio shape, which may
+        # be completion-only [B, C] rather than full sequence [B, T].
+        if weight.dim() == 1:
+            # Turn-level [B]: broadcast to per-token advantages.
+            base_adv = weight.unsqueeze(-1) * action_mask_f
+            modified_weight = (base_adv + kl_penalty_full) * action_mask_f
+        elif weight.shape[-1] == T:
+            # Full sequence [B, T]: add KL directly, mask to action tokens.
+            modified_weight = (weight + kl_penalty_full) * action_mask_f
+        else:
+            # Completion-only [B, C]: extract KL for action tokens and align
+            # to completion positions.
+            C = weight.shape[-1]
+            kl_penalty_completion = torch.zeros(
+                B, C, device=weight.device, dtype=weight.dtype
+            )
+            completion_mask = torch.zeros(
+                B, C, device=weight.device, dtype=weight.dtype
+            )
+            for b in range(B):
+                action_indices = action_mask[b].nonzero(as_tuple=True)[0]
+                n_actions = min(action_indices.numel(), C)
+                if n_actions > 0:
+                    kl_penalty_completion[b, :n_actions] = kl_penalty_full[b, action_indices[:n_actions]]
+                    completion_mask[b, :n_actions] = 1.0
+            modified_weight = (weight + kl_penalty_completion) * completion_mask
+
+        # Create modified batch (shallow copy with updated weight)
+        modified_batch = dict(batch)
+        modified_batch["weight"] = modified_weight
+
+        # Compute metrics (masked to action tokens)
+        mask_sum = action_mask.sum()
+        if mask_sum > 0:
+            masked_kl = reverse_kl * action_mask_f
+            kl_mean = masked_kl.sum() / mask_sum
+            kl_std = ((masked_kl - kl_mean * action_mask_f) ** 2).sum() / mask_sum
+            kl_std = kl_std.sqrt()
+        else:
+            kl_mean = reverse_kl.new_zeros(())
+            kl_std = reverse_kl.new_zeros(())
+
+        metrics = {
+            "kl_mean": kl_mean.detach(),
+            "kl_std": kl_std.detach(),
+            "kl_penalty_mean": (kl_penalty_full * action_mask_f).sum().detach() / mask_sum.clamp(min=1),
+        }
+
+        return modified_batch, metrics
+
+
 # ---- Credit Assigners ----
 
 @dataclass
diff --git a/src/ludic/training/loss.py b/src/ludic/training/loss.py
index 18fe078..54aa89b 100644
--- a/src/ludic/training/loss.py
+++ b/src/ludic/training/loss.py
@@ -339,7 +339,7 @@ class ClippedSurrogateLoss:
         L_clip = - E[ min(r * A, clip(r, 1 - eps_low, 1 + eps_high) * A) ]
 
     Expects:
-    - batch["weight"]: A (advantages) [B]
+    - batch["weight"]: A (advantages) [B] or token-level [B, T]
     - batch[old_logp_key]: log π_old(a|s) [B]
     - input_ids / attention_mask / action_mask for π_new.
@@ -365,7 +365,7 @@ def __post_init__(self) -> None:
     def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]]:
         input_ids = batch["input_ids"]
         action_mask = batch["action_mask"]
-        advantages = batch["weight"]  # [B]
+        advantages = batch["weight"]
 
         if self.old_logp_key not in batch:
             raise KeyError(f"ClippedSurrogateLoss requires '{self.old_logp_key}' in batch.")
@@ -387,13 +387,28 @@ def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]]
         if self.ratio_clip is not None:
             ratio = torch.clamp(ratio, max=self.ratio_clip)
 
-        unclipped = ratio * advantages
-        clipped = torch.clamp(
+        ratio_clipped = torch.clamp(
             ratio, 1.0 - self.clip_eps_low, 1.0 + self.clip_eps_high
-        ) * advantages
-
-        obj = torch.min(unclipped, clipped)
-        loss = -obj.mean()
+        )
+        if advantages.dim() == 2:
+            if advantages.shape != action_mask.shape:
+                raise ValueError(
+                    "ClippedSurrogateLoss expects token-level weights to match action_mask shape "
+                    f"{tuple(action_mask.shape)}, got {tuple(advantages.shape)}."
+                )
+            adv_mask = action_mask.to(advantages.dtype)
+            ratio_exp = ratio.unsqueeze(-1)
+            ratio_clipped_exp = ratio_clipped.unsqueeze(-1)
+            unclipped = ratio_exp * advantages
+            clipped = ratio_clipped_exp * advantages
+            obj = torch.min(unclipped, clipped) * adv_mask
+            adv_denom = adv_mask.sum().clamp(min=1.0)
+            loss = -obj.sum() / adv_denom
+        else:
+            unclipped = ratio * advantages
+            clipped = ratio_clipped * advantages
+            obj = torch.min(unclipped, clipped)
+            loss = -obj.mean()
 
         ppo_clip_frac = (
             (ratio > 1.0 + self.clip_eps_high) | (ratio < 1.0 - self.clip_eps_low)
@@ -403,6 +418,15 @@ def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]]
         else:
             ratio_clip_frac = torch.zeros((), device=ratio.device, dtype=ratio.dtype)
 
+        if advantages.dim() == 2:
+            masked_adv = advantages * adv_mask
+            adv_mean = masked_adv.sum() / adv_denom
+            adv_var = ((masked_adv - adv_mean) * adv_mask).pow(2).sum() / adv_denom
+            adv_std = adv_var.sqrt()
+        else:
+            adv_mean = advantages.mean()
+            adv_std = advantages.std(unbiased=False)
+
         stats = {
             "loss": loss.detach(),
             "ratio_mean": ratio.mean().detach(),
@@ -410,8 +434,8 @@ def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]]
             "clip_frac": ppo_clip_frac.detach(),
             "ratio_clip_frac": ratio_clip_frac.detach(),
             "kl_actor_policy": mismatch_kl.mean().detach(),
-            "adv_mean": advantages.mean().detach(),
-            "adv_std": advantages.std(unbiased=False).detach(),
+            "adv_mean": adv_mean.detach(),
+            "adv_std": adv_std.detach(),
             "logp_mean": logp_action.mean().detach(),
         }
         return loss, stats
@@ -737,6 +761,91 @@ def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]]
         return loss, stats
 
+# ---------------------------------------------------------------------------
+# On-Policy Distillation
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ReverseKLLoss:
+    """
+    On-Policy Distillation loss: per-token reverse KL from student to teacher.
+
+    Loss = E[ log π_student(a_t|s_t) - log π_teacher(a_t|s_t) ]
+
+    This pushes the student to match the teacher's distribution on the
+    student's own samples (on-policy). The loss is minimized when the student
+    assigns the same probability as the teacher to each token.
+
+    Key insight: unlike forward KL, which mode-covers, reverse KL is
+    mode-seeking. The student learns to approximate one specific behavior
+    (the teacher's) rather than spreading across all plausible behaviors.
+
+    Expects:
+    - batch["input_ids"]: [B, T] tokenized sequences
+    - batch["action_mask"]: [B, T] mask for completion tokens
+    - batch["teacher_logps"]: [B, T] per-token teacher logprobs
+
+    Reference: https://thinkingmachines.ai/blog/on-policy-distillation
+    """
+
+    coeff: float = 1.0
+    length_normalize: bool = False
+
+    @jaxtyped(typechecker=typechecker)
+    def compute(self, logits: Logits, batch: Batch) -> Tuple[Tensor, Dict[str, Any]]:
+        if "teacher_logps" not in batch:
+            raise KeyError(
+                "ReverseKLLoss requires batch['teacher_logps']. "
+                "Ensure teacher logprobs are computed during rollout generation "
+                "(e.g., via RolloutEngine with teacher_client parameter)."
+            )
+
+        input_ids = batch["input_ids"]
+        action_mask = batch["action_mask"]
+        teacher_logps = batch["teacher_logps"]
+
+        # Current policy logprobs: [B, T-1]
+        token_logp = compute_token_logp(logits, input_ids)
+
+        # Shift teacher logps to align with predictions:
+        # teacher_logps[t] is the logprob of token t, but we predict token t from t-1
+        teacher_logps_shifted = teacher_logps[:, 1:]          # [B, T-1]
+        token_mask = action_mask[:, 1:].to(token_logp.dtype)  # [B, T-1]
+        token_counts = token_mask.sum(dim=-1).clamp(min=1.0)  # [B]
+
+        # Reverse KL: log(student) - log(teacher).
+        # Positive when the student is more confident than the teacher on a token,
+        # negative when the teacher is more confident.
+        reverse_kl = token_logp - teacher_logps_shifted
+
+        # Masked sum over tokens
+        per_sample_kl = (reverse_kl * token_mask).sum(dim=-1)  # [B]
+
+        if self.length_normalize:
+            per_sample_kl = per_sample_kl / token_counts
+
+        loss = self.coeff * per_sample_kl.mean()
+
+        # Stats - computed over masked tokens only
+        mask = token_mask > 0
+        if mask.any():
+            masked_kl = reverse_kl.masked_select(mask)
+            kl_mean = masked_kl.mean()
+            kl_std = masked_kl.std(unbiased=False)
+        else:
+            kl_mean = loss.new_zeros(())
+            kl_std = loss.new_zeros(())
+
+        stats: Dict[str, Any] = {
+            "loss": loss.detach(),
+            "reverse_kl_mean": kl_mean.detach(),
+            "reverse_kl_std": kl_std.detach(),
+            "avg_action_tokens": token_counts.mean().detach(),
+        }
+        return loss, stats
+
+
 # ---------------------------------------------------------------------------
 # Composite loss
 # ---------------------------------------------------------------------------
diff --git a/src/ludic/training/scoring.py b/src/ludic/training/scoring.py
new file mode 100644
index 0000000..2e712f3
--- /dev/null
+++ b/src/ludic/training/scoring.py
@@ -0,0 +1,200 @@
+"""
+Intrinsic scoring protocols for agents.
+
+Intrinsic scores are agent-local evaluations of action quality, computed
+during rollout generation. They are analogous to the parser (which validates
+action syntax) but evaluate action quality instead.
+
+Two scorer types:
+- TokenLevelScorer: per-token scores (e.g., teacher logprobs for OPD)
+- ActionLevelScorer: scalar per-action scores (e.g., LLM-as-judge)
+
+Scores are attached to AgentActStep and flow into SAWItem for training.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List, Protocol, Union, runtime_checkable
+
+
+@runtime_checkable
+class TokenLevelScorer(Protocol):
+    """
+    Per-token intrinsic scoring.
+
+    Use cases:
+    - Teacher logprobs for On-Policy Distillation (OPD)
+    - Token-level reward models
+    - Per-token confidence scores
+
+    The scorer receives prompt and completion token IDs and returns one score
+    per completion token.
+    """
+
+    name: str
+
+    async def score_tokens(
+        self,
+        prompt_token_ids: List[int],
+        completion_token_ids: List[int],
+    ) -> List[float]:
+        """
+        Compute per-token scores for the given completion.
+
+        Args:
+            prompt_token_ids: Prompt token IDs (for conditioning context).
+            completion_token_ids: Completion token IDs to score.
+
+        Returns:
+            List of scores, one per completion token.
+            Length must equal len(completion_token_ids).
+        """
+        ...
+
+
+@runtime_checkable
+class ActionLevelScorer(Protocol):
+    """
+    Per-action intrinsic scoring (scalar).
+
+    Use cases:
+    - LLM-as-a-judge
+    - Verifier models
+    - Self-consistency scores
+
+    The scorer receives the full context and returns a single scalar.
+    """
+
+    name: str
+
+    async def score_action(self, prompt: str, completion: str) -> float:
+        """
+        Compute a scalar score for the action.
+
+        Args:
+            prompt: The prompt text (rendered messages).
+            completion: The agent's completion text.
+
+        Returns:
+            Scalar score for the action.
+        """
+        ...
+
+
+IntrinsicScorer = Union[TokenLevelScorer, ActionLevelScorer]
+
+
+@dataclass
+class VLLMTeacherScorer:
+    """
+    TokenLevelScorer backed by a vLLM server.
+
+    Computes teacher logprobs via teacher-forced prefill using
+    the /v1/completions endpoint with echo=True and max_tokens=0.
+
+    Assumptions:
+    - Teacher and student use the SAME TOKENIZER. Token IDs from
+      the student are sent directly to the teacher. If the tokenizers
+      differ, logprobs will be meaningless or the request will fail.
+    - The full sequence (prompt + completion) fits within the teacher's
+      context window. If the sequence exceeds the context limit, vLLM
+      may truncate and return fewer logprobs than completion tokens,
+      causing a length-mismatch error during batch collation.
+
+    For models with different tokenizers, you would need a custom scorer
+    that re-tokenizes the text with the teacher's tokenizer.
+    """
+
+    base_url: str
+    model: str
+    name: str = "teacher_logps"
+    timeout: float = 60.0
+
+    async def score_tokens(
+        self,
+        prompt_token_ids: List[int],
+        completion_token_ids: List[int],
+    ) -> List[float]:
+        """
+        Compute per-token logprobs from the teacher model.
+
+        Uses vLLM's echo mode to get logprobs for existing tokens.
+        The full sequence (prompt + completion) is sent so the teacher
+        can properly condition on the prompt context.
+        """
+        import aiohttp
+
+        # Send the full sequence so the teacher conditions on the prompt
+        full_sequence = prompt_token_ids + completion_token_ids
+        prompt_len = len(prompt_token_ids)
+
+        url = f"{self.base_url}/v1/completions"
+        payload = {
+            "model": self.model,
+            "prompt": full_sequence,
+            "max_tokens": 0,
+            "echo": True,
+            "logprobs": 1,
+        }
+
+        async with aiohttp.ClientSession() as session:
+            async with session.post(
+                url,
+                json=payload,
+                timeout=aiohttp.ClientTimeout(total=self.timeout),
+            ) as resp:
+                resp.raise_for_status()
+                data = await resp.json()
+
+        logprobs_data = data["choices"][0].get("logprobs", {})
+        token_logprobs = logprobs_data.get("token_logprobs", [])
+
+        # Extract only the completion-token logprobs (skip prompt tokens).
+        # The first token of the full sequence has logprob=None, but completion
+        # tokens all have valid logprobs since they're conditioned on prior tokens.
+        completion_logprobs = token_logprobs[prompt_len:]
+
+        return [float(lp) if lp is not None else 0.0 for lp in completion_logprobs]
+
+
+def make_vllm_teacher_scorer(
+    base_url: str,
+    model: str,
+    *,
+    name: str = "teacher_logps",
+    timeout: float = 60.0,
+) -> TokenLevelScorer:
+    """
+    Create a TokenLevelScorer backed by vLLM for teacher logprobs.
+
+    This is used for On-Policy Distillation (OPD), where a teacher model
+    provides per-token supervision via teacher-forced prefill.
+
+    Important: the teacher model MUST use the same tokenizer as the student.
+    Token IDs are passed directly without re-tokenization. Also ensure the
+    full sequence (prompt + completion) fits within the teacher's context
+    window to avoid truncation errors.
+
+    Args:
+        base_url: vLLM server URL (e.g., "http://localhost:8001").
+        model: Teacher model name (must share a tokenizer with the student).
+        name: Attachment key for scores (default: "teacher_logps").
+        timeout: Request timeout in seconds.
+
+    Returns:
+        TokenLevelScorer that computes teacher logprobs.
+
+    Example:
+        >>> teacher = make_vllm_teacher_scorer(
+        ...     base_url="http://localhost:8001",
+        ...     model="Qwen/Qwen2.5-7B-Instruct",
+        ... )
+        >>> agent = Agent(client=client, ..., scorers=[teacher])
+    """
+    return VLLMTeacherScorer(
+        base_url=base_url,
+        model=model,
+        name=name,
+        timeout=timeout,
+    )
diff --git a/src/ludic/training/types.py b/src/ludic/training/types.py
index 7ffcfa9..b472a3f 100644
--- a/src/ludic/training/types.py
+++ b/src/ludic/training/types.py
@@ -127,6 +127,27 @@ def __post_init__(self) -> None:
         raise TypeError("ActorTokenLogps.token_logps must be a List[float].")
 
+@dataclass(frozen=True)
+class TeacherLogprobs:
+    """
+    Per-token logprobs from a teacher model, used for on-policy distillation.
+
+    `token_logps[i]` corresponds to the teacher's logprob for the chosen token
+    at completion position i. This is computed by querying the teacher model
+    on the student's sampled tokens.
+
+    Reference: https://thinkingmachines.ai/blog/on-policy-distillation
+    """
+
+    token_logps: List[float]
+
+    def __post_init__(self) -> None:
+        if not isinstance(self.token_logps, list) or not all(
+            isinstance(v, (int, float)) for v in self.token_logps
+        ):
+            raise TypeError("TeacherLogprobs.token_logps must be a List[float].")
+
+
 @dataclass
 class SampleAttachments:
     """
@@ -137,6 +158,7 @@ class SampleAttachments:
     """
 
     actor_logps: Optional[ActorTokenLogps] = None
+    teacher_logps: Optional[TeacherLogprobs] = None
 
 class HasActorLogps(Protocol):
@@ -186,6 +208,10 @@ class SAWItem:
     def actor_logps(self) -> Optional[ActorTokenLogps]:
         return self.attachments.actor_logps
 
+    @property
+    def teacher_logps(self) -> Optional[TeacherLogprobs]:
+        return self.attachments.teacher_logps
+
 @dataclass
 class SAWBatch:
     """
diff --git a/src/ludic/types.py b/src/ludic/types.py
index 62ebfea..291540a 100644
--- a/src/ludic/types.py
+++ b/src/ludic/types.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 from dataclasses import dataclass, field, asdict
-from typing import Any, Dict, List, Union, Optional, Literal, Mapping
+from typing import Any, Awaitable, Dict, List, Union, Optional, Literal, Mapping
 import logging
 import time
 import uuid
@@ -184,6 +184,8 @@ class AgentStep:
     parent_id: Optional[str] = None
     tool_calls: Optional[List[Dict[str, Any]]] = None
     tool_results: Optional[List[Dict[str, Any]]] = None
+    intrinsic_scores: Dict[str, Any] = field(default_factory=dict)
+    pending_score_tasks: Dict[str, Awaitable[Any]] = field(default_factory=dict)
 
 @dataclass
 class EnvironmentStep:
diff --git a/tests/test_credit_assignment.py b/tests/test_credit_assignment.py
index f02b61b..ffd4cb4 100644
--- a/tests/test_credit_assignment.py
+++ b/tests/test_credit_assignment.py
@@ -1,5 +1,6 @@
 import pytest
 import math
+import torch
 
 from ludic.types import Rollout, EnvironmentStep, TokenTrace
 from ludic.training.credit_assignment import (
@@ -7,6 +8,7 @@
     EpisodicReturn,
     PerStepReward,
     GroupNormalizedReturn,
+    KLCreditModifier,
 )
 
 # ---- Helper to build a simple rollout ----
@@ -243,3 +245,293 @@ def test_group_normalized_return_invalid_group_size():
     with pytest.raises(ValueError, match="group_size must be positive"):
         GroupNormalizedReturn(group_size=-1)
+
+
+# ---- KLCreditModifier Tests ----
+
+def _make_batch(
+    *,
+    weight: torch.Tensor,
+    actor_logps: torch.Tensor,
+    teacher_logps: torch.Tensor,
+    action_mask: torch.Tensor,
+) -> dict:
+    """Helper to create a batch dict for KLCreditModifier tests."""
+    return {
+        "weight": weight,
+        "actor_logps": actor_logps,
+        "teacher_logps": teacher_logps,
+        "action_mask": action_mask,
+    }
+
+
+def test_kl_credit_modifier_basic():
+    """Test basic KL penalty addition to advantages."""
+    # Batch of 2 samples, 4 tokens each.
+    # actor_logps - teacher_logps = reverse KL
+    # KL penalty = -coeff * reverse_kl
+    weight = torch.zeros(2, 4)
+    actor_logps = torch.tensor([
+        [-1.0, -2.0, -1.5, -0.5],  # sample 0
+        [-2.0, -1.0, -3.0, -1.0],  # sample 1
+    ])
+    teacher_logps = torch.tensor([
+        [-1.5, -1.5, -1.5, -1.5],  # sample 0: teacher
+        [-1.5, -1.5, -1.5, -1.5],  # sample 1: teacher
+ ]) + action_mask = torch.tensor([ + [0, 1, 1, 1], # sample 0: first token is prompt + [0, 0, 1, 1], # sample 1: first two tokens are prompt + ]) + + batch = _make_batch( + weight=weight, + actor_logps=actor_logps, + teacher_logps=teacher_logps, + action_mask=action_mask, + ) + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, metrics = modifier.modify(batch) + + # reverse_kl = actor - teacher + # sample 0: [-1.0 - (-1.5), -2.0 - (-1.5), -1.5 - (-1.5), -0.5 - (-1.5)] + # = [0.5, -0.5, 0.0, 1.0] + # sample 1: [-2.0 - (-1.5), -1.0 - (-1.5), -3.0 - (-1.5), -1.0 - (-1.5)] + # = [-0.5, 0.5, -1.5, 0.5] + # kl_penalty = -1.0 * reverse_kl (masked) + # sample 0: [0, 0.5, 0.0, -1.0] (first token masked) + # sample 1: [0, 0, 1.5, -0.5] (first two masked) + + expected_weight = torch.tensor([ + [0.0, 0.5, 0.0, -1.0], + [0.0, 0.0, 1.5, -0.5], + ]) + + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + # Check metrics + assert "kl_mean" in metrics + assert "kl_std" in metrics + assert "kl_penalty_mean" in metrics + + +def test_kl_credit_modifier_with_existing_advantage(): + """Test that KL penalty is added to existing advantages.""" + # Start with non-zero advantages + weight = torch.tensor([ + [0.0, 1.0, 1.0, 1.0], # sample 0 + [0.0, 0.0, 2.0, 2.0], # sample 1 + ]) + actor_logps = torch.tensor([ + [-1.0, -1.0, -1.0, -1.0], + [-1.0, -1.0, -1.0, -1.0], + ]) + teacher_logps = torch.tensor([ + [-2.0, -2.0, -2.0, -2.0], # actor closer to 0 = higher prob + [-2.0, -2.0, -2.0, -2.0], + ]) + action_mask = torch.tensor([ + [0, 1, 1, 1], + [0, 0, 1, 1], + ]) + + batch = _make_batch( + weight=weight, + actor_logps=actor_logps, + teacher_logps=teacher_logps, + action_mask=action_mask, + ) + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, _ = modifier.modify(batch) + + # reverse_kl = -1.0 - (-2.0) = 1.0 everywhere + # kl_penalty = -1.0 * 1.0 = -1.0 (penalty for being overconfident vs teacher) + # Modified weight = original + penalty (masked) + 
expected_weight = torch.tensor([ + [0.0, 1.0 - 1.0, 1.0 - 1.0, 1.0 - 1.0], # [0, 0, 0, 0] + [0.0, 0.0, 2.0 - 1.0, 2.0 - 1.0], # [0, 0, 1, 1] + ]) + + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + +def test_kl_credit_modifier_coeff(): + """Test that coeff scales the KL penalty.""" + weight = torch.zeros(1, 3) + actor_logps = torch.tensor([[-1.0, -1.0, -1.0]]) + teacher_logps = torch.tensor([[-2.0, -2.0, -2.0]]) + action_mask = torch.tensor([[1, 1, 1]]) + + batch = _make_batch( + weight=weight, + actor_logps=actor_logps, + teacher_logps=teacher_logps, + action_mask=action_mask, + ) + + # With coeff=2.0, penalty should be doubled + modifier = KLCreditModifier(coeff=2.0) + modified_batch, _ = modifier.modify(batch) + + # reverse_kl = 1.0, penalty = -2.0 * 1.0 = -2.0 + expected_weight = torch.tensor([[-2.0, -2.0, -2.0]]) + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + +def test_kl_credit_modifier_missing_actor_logps(): + """Test that missing actor_logps raises KeyError.""" + batch = { + "weight": torch.zeros(1, 3), + "teacher_logps": torch.zeros(1, 3), + "action_mask": torch.ones(1, 3), + } + + modifier = KLCreditModifier() + with pytest.raises(KeyError, match="actor_logps"): + modifier.modify(batch) + + +def test_kl_credit_modifier_missing_teacher_logps(): + """Test that missing teacher_logps raises KeyError.""" + batch = { + "weight": torch.zeros(1, 3), + "actor_logps": torch.zeros(1, 3), + "action_mask": torch.ones(1, 3), + } + + modifier = KLCreditModifier() + with pytest.raises(KeyError, match="teacher_logps"): + modifier.modify(batch) + + +def test_kl_credit_modifier_zeroes_prompt_tokens(): + """ + Test that prompt tokens (action_mask=0) always have zero weight, + even if the input weight tensor had non-zero values there. + + This prevents prompt-length-dependent loss scaling. 
+ """ + # Input weight has non-zero values on ALL tokens (as might happen + # if trajectory-level advantages were broadcast without masking) + weight = torch.tensor([ + [5.0, 5.0, 5.0, 5.0], # advantage of 5.0 broadcast to all tokens + ]) + actor_logps = torch.tensor([[-1.0, -1.0, -1.0, -1.0]]) + teacher_logps = torch.tensor([[-1.0, -1.0, -1.0, -1.0]]) # identical, so KL=0 + action_mask = torch.tensor([ + [0, 0, 1, 1], # first two tokens are prompt + ]) + + batch = _make_batch( + weight=weight, + actor_logps=actor_logps, + teacher_logps=teacher_logps, + action_mask=action_mask, + ) + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, _ = modifier.modify(batch) + + # Key assertion: prompt tokens (positions 0, 1) must have zero weight, + # even though input weight was 5.0 there + # Action tokens (positions 2, 3) should have weight=5.0 (advantage + 0 KL penalty) + expected_weight = torch.tensor([[0.0, 0.0, 5.0, 5.0]]) + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + +def test_kl_credit_modifier_turn_level_weight(): + """ + Test that turn-level (1D) weights are broadcast to per-token advantages. + + Credit assigners produce one weight per turn/step, but KL is per-token. + The modifier broadcasts the scalar weight and adds per-token KL. 
+ """ + # Turn-level weight: one value per sample (shape [B]) + weight = torch.tensor([2.0, -1.0]) # two samples with advantages 2.0 and -1.0 + + # Token-level tensors: shape [B, T] + actor_logps = torch.tensor([ + [-1.0, -1.0, -1.0, -1.0], + [-2.0, -2.0, -2.0, -2.0], + ]) + teacher_logps = torch.tensor([ + [-1.5, -1.5, -1.5, -1.5], # KL = -1.0 - (-1.5) = 0.5, penalty = -0.5 + [-1.5, -1.5, -1.5, -1.5], # KL = -2.0 - (-1.5) = -0.5, penalty = 0.5 + ]) + action_mask = torch.tensor([ + [0, 1, 1, 1], # sample 0: 3 action tokens (positions 1,2,3) + [0, 0, 1, 1], # sample 1: 2 action tokens (positions 2,3) + ]) + + batch = { + "weight": weight, # [B] - turn-level + "actor_logps": actor_logps, + "teacher_logps": teacher_logps, + "action_mask": action_mask, + } + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, _ = modifier.modify(batch) + + # Output is per-token [B, T] with masked prompt positions. + # Sample 0: weight=2.0 + kl_penalty=[-0.5, -0.5, -0.5] -> [1.5, 1.5, 1.5] + # Sample 1: weight=-1.0 + kl_penalty=[0.5, 0.5] -> [-0.5, -0.5] + expected_weight = torch.tensor([ + [0.0, 1.5, 1.5, 1.5], + [0.0, 0.0, -0.5, -0.5], + ]) + + assert modified_batch["weight"].shape == (2, 4) + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) + + +def test_kl_credit_modifier_completion_only_weight(): + """ + Test that completion-only weights [B, C] stay in completion-only format. + + When weight is stored compactly for just completion tokens (C < T), + the modifier extracts KL for those positions and adds it, keeping [B, C] shape. 
+ """ + # Completion-only weight: [B, C] where C is number of completion tokens + # Sample 0 has 3 completion tokens, sample 1 has 2 + # But they're padded to same length C=3 + weight = torch.tensor([ + [1.0, 2.0, 3.0], # sample 0: weights for 3 action tokens + [4.0, 5.0, 0.0], # sample 1: weights for 2 action tokens (padded) + ]) + + # Full sequence tensors: [B, T] where T=6 (prompt + completion) + actor_logps = torch.tensor([ + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + ]) + teacher_logps = torch.tensor([ + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], # KL = 0 + [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0], + ]) + action_mask = torch.tensor([ + [0, 0, 0, 1, 1, 1], # sample 0: 3 prompt, 3 completion + [0, 0, 0, 0, 1, 1], # sample 1: 4 prompt, 2 completion + ]) + + batch = { + "weight": weight, # [B, C=3] - completion-only + "actor_logps": actor_logps, + "teacher_logps": teacher_logps, + "action_mask": action_mask, + } + + modifier = KLCreditModifier(coeff=1.0) + modified_batch, _ = modifier.modify(batch) + + # With KL=0, modified_weight = weight (unchanged since kl_penalty is 0) + # Output stays in completion-only format [B, C] + expected_weight = torch.tensor([ + [1.0, 2.0, 3.0], # sample 0: unchanged + [4.0, 5.0, 0.0], # sample 1: unchanged (padding preserved) + ]) + + assert modified_batch["weight"].shape == (2, 3) # stays completion-only + assert torch.allclose(modified_batch["weight"], expected_weight, atol=1e-6) diff --git a/tests/test_loss.py b/tests/test_loss.py index 6832009..078d7ae 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -257,6 +257,31 @@ def test_gspo_loss_length_normalize_affects_ratio(): assert torch.allclose(loss_norm, expected_loss_norm, atol=1e-4) +def test_gspo_loss_token_advantages_masked_mean(): + logits = torch.tensor([[ + [0.0, 0.0], # predicts token at pos 1 + [2.0, 0.0], # predicts token at pos 2 + [0.0, 0.0], # unused + ]], dtype=torch.float32) + input_ids = torch.tensor([[0, 1, 0]], 
dtype=torch.long) + action_mask = torch.tensor([[0, 1, 1]], dtype=torch.float32) + + logp_action = compute_logp_action(logits, input_ids, action_mask) + batch = { + "input_ids": input_ids, + "action_mask": action_mask, + "weight": torch.tensor([[0.0, 2.0, -1.0]], dtype=torch.float32), + "old_logp_action": logp_action, + } + + loss_fn = ClippedSurrogateLoss(clip_eps_low=1.0, clip_eps_high=1.0, length_normalize=False) + loss, _ = loss_fn.compute(logits, batch) + + # ratio = 1.0, so loss is the masked mean of advantages over action tokens. + # (2.0 + -1.0) / 2 = 0.5 => loss = -0.5 + assert torch.allclose(loss, torch.tensor(-0.5), atol=1e-4) + + def test_gspo_loss_upper_clip_positive_advantage(): logits = torch.tensor([[[0.0, 0.0], [0.0, 0.0]]], dtype=torch.float32) input_ids = torch.tensor([[0, 1]], dtype=torch.long)
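
The tests above pin down the KL credit-modification arithmetic: keep the task advantage on action tokens, add a per-token penalty of `-coeff * (actor_logp - teacher_logp)`, and force prompt tokens (`action_mask == 0`) to zero weight. As a standalone sanity check of that rule, here is a minimal pure-Python sketch for a single sample; the `kl_modify` helper is hypothetical and not part of Ludic, which operates on batched tensors via `KLCreditModifier`.

```python
from typing import List


def kl_modify(
    weight: List[float],
    actor_logps: List[float],
    teacher_logps: List[float],
    action_mask: List[int],
    coeff: float = 1.0,
) -> List[float]:
    """Hypothetical single-sample sketch of KL credit modification.

    Per-token reverse-KL estimate: actor_logp - teacher_logp.
    Penalty = -coeff * kl, added to the advantage on action tokens;
    prompt tokens (mask == 0) are forced to zero weight.
    """
    out: List[float] = []
    for w, a, t, m in zip(weight, actor_logps, teacher_logps, action_mask):
        if m == 0:
            out.append(0.0)  # prompt token: never contributes to the loss
        else:
            out.append(w - coeff * (a - t))  # advantage + KL penalty
    return out


# Mirrors sample 0 of test_kl_credit_modifier_basic above:
print(kl_modify(
    [0.0, 0.0, 0.0, 0.0],          # zero task advantage
    [-1.0, -2.0, -1.5, -0.5],      # actor logprobs
    [-1.5, -1.5, -1.5, -1.5],      # teacher logprobs
    [0, 1, 1, 1],                  # first token is prompt
))
# [0.0, 0.5, 0.0, -1.0]
```

The same rule explains the "existing advantage" test: with a uniform reverse KL of 1.0 and `coeff=1.0`, every action-token weight drops by exactly 1.0.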