diff --git a/claas/core/types.py b/claas/core/types.py index 0b9e106..55929a7 100644 --- a/claas/core/types.py +++ b/claas/core/types.py @@ -32,6 +32,8 @@ class TrainingConfig: max_grad_norm: float = 1.0 kl_reg_weight: float = 0.0 teacher_top_k: int = 100 + steps_per_batch: int = 4 + feedback_repetitions: int = 1 class SDPOLossInput(BaseModel): diff --git a/claas/eval/README.md b/claas/eval/README.md index 92438fe..db505b7 100644 --- a/claas/eval/README.md +++ b/claas/eval/README.md @@ -26,15 +26,6 @@ metrics: # metrics to evaluate per step num_steps: 20 batch_size: 4 -steps_per_batch: 4 # gradient updates per batch -feedback_repetitions: 1 # times to repeat feedback string -training: # forwarded to /v1/feedback training config - learning_rate: 3e-5 - alpha: 0.5 - is_clip: 5.0 - max_grad_norm: 1.0 - kl_reg_weight: 0.0 - teacher_top_k: 100 collapse_steps: [0, 5, 10, 15, 19] # steps where collapse metric runs plots: true # generate matplotlib plots seed: 42 @@ -42,6 +33,16 @@ lora_id_prefix: eval output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ} openclaw_url: http://localhost:18789 # OpenClaw gateway (null = use CLaaS API directly) + +training: # forwarded to /v1/feedback TrainingConfig + learning_rate: 3e-5 + alpha: 0.5 + is_clip: 5.0 + max_grad_norm: 1.0 + kl_reg_weight: 0.0 + teacher_top_k: 100 + steps_per_batch: 4 # gradient updates per batch + feedback_repetitions: 1 # times to repeat feedback string ``` ### Overriding config via CLI diff --git a/claas/eval/configs/base.yaml b/claas/eval/configs/base.yaml index c16f34d..26e658b 100644 --- a/claas/eval/configs/base.yaml +++ b/claas/eval/configs/base.yaml @@ -22,8 +22,12 @@ plots: true num_steps: 20 batch_size: 4 -steps_per_batch: 4 -feedback_repetitions: 1 +seed: 42 +lora_id_prefix: eval +output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ} + +openclaw_url: http://localhost:18789 + training: learning_rate: 3e-5 alpha: 0.5 @@ -31,8 +35,5 @@ training: max_grad_norm: 1.0 kl_reg_weight: 0.0 teacher_top_k: 100 -seed: 42 -lora_id_prefix: eval -output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ} - -openclaw_url: http://localhost:18789 + steps_per_batch: 4 + feedback_repetitions: 1 diff --git a/claas/eval/runner.py b/claas/eval/runner.py index ac59041..4c537b9 100644 --- a/claas/eval/runner.py +++ b/claas/eval/runner.py @@ -81,6 +81,7 @@ async def _submit_feedback( adv_abs_mean_raw=metadata["adv_abs_mean_raw"], completion_len=metadata["completion_len"], batch_size=metadata["batch_size"], + steps_per_batch_applied=metadata["steps_per_batch_applied"], ) return LocalDistillMetrics( @@ -88,6 +89,7 @@ async def _submit_feedback( kl_reg=metadata.get("kl_reg"), mean_is_ratio=metadata.get("mean_is_ratio"), clip_fraction=metadata.get("clip_fraction"), + steps_per_batch_applied=metadata["steps_per_batch_applied"], ) @@ -190,6 +192,7 @@ def _load_completed_steps(output_dir: str, preference: str) -> list[StepResult]: prompt_used=data["prompt_used"], response_text=data.get("response_text"), timing_s=data.get("timing_s", 0.0), + sub_step_count=data.get("sub_step_count", 1), )) return steps @@ -362,8 +365,8 @@ async def run_preference_experiment( for step in range(resume_from, config.num_steps): step_start = time.perf_counter() - # Determine feedback string - feedback_str = " ".join([pref.feedback_string] * config.feedback_repetitions) + # Feedback repetition is a training concern configured via TrainingConfig. + feedback_str = pref.feedback_string # Collect samples for this step (batch_size >= 1) samples: list[FeedbackItem] = [] @@ -398,29 +401,21 @@ async def run_preference_experiment( if response_text is None: response_text = "I'd be happy to help you with that." - # Submit feedback — possibly multiple gradient steps on same batch + # Submit feedback for this step. Training engine applies steps_per_batch. sdpo_metrics = None - sub_steps_completed = 0 if samples: - for sub_step in range(config.steps_per_batch): - try: - sdpo_metrics = await _submit_feedback( - config, actual_lora_id, samples, - ) - sub_steps_completed += 1 - except (httpx.HTTPError, KeyError) as e: - logger.warning( - "[%s] Step %d sub-step %d feedback failed: %s", - pref.name, step, sub_step, e, - ) - break - - if config.steps_per_batch > 1: - logger.info( - "[%s] Step %d: %d sub-steps completed", - pref.name, step, sub_steps_completed, + try: + sdpo_metrics = await _submit_feedback( + config, actual_lora_id, samples, + ) + except (httpx.HTTPError, KeyError) as e: + logger.warning( + "[%s] Step %d feedback failed: %s", + pref.name, step, e, ) + sub_step_count = sdpo_metrics.steps_per_batch_applied if sdpo_metrics else 1 + # Measure eval try: eval_metrics = await _measure_eval_metrics( @@ -447,7 +442,7 @@ async def run_preference_experiment( ], response_text=response_text if needs_generation else None, timing_s=timing_s, - sub_step_count=sub_steps_completed if sub_steps_completed > 0 else 1, + sub_step_count=sub_step_count, ) result.steps.append(step_result) diff --git a/claas/eval/types.py b/claas/eval/types.py index 02978dc..a43fce5 100644 --- a/claas/eval/types.py +++ b/claas/eval/types.py @@ -94,8 +94,6 @@ class EvalConfig: openclaw_url: Optional[str] = None base_model: str = "Qwen/Qwen3-8B" batch_size: int = 4 - steps_per_batch: int = 4 - feedback_repetitions: int = 1 training: TrainingConfig = field(default_factory=TrainingConfig) @@ -117,6 +115,7 @@ class LocalDistillMetrics: kl_reg: float | None mean_is_ratio: float | None clip_fraction: float | None + steps_per_batch_applied: int = 1 @dataclass @@ -131,6 +130,7 @@ class TinkerDistillMetrics: adv_abs_mean_raw: float completion_len: int = 0 batch_size: int = 0 + steps_per_batch_applied: int = 1 @dataclass diff --git a/claas/training/distillation.py b/claas/training/distillation.py index 5487303..3e2e007 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -7,7 +7,7 @@ import os import tempfile from contextlib import nullcontext as _nullcontext -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, TypedDict, cast import torch @@ -35,6 +35,17 @@ logger = logging.getLogger(__name__) +class PreparedSample(TypedDict): + full_ids: torch.Tensor + response_ids: torch.Tensor + response_start: int + response_mask: torch.Tensor + base_logprobs: torch.Tensor + teacher_logprobs: torch.Tensor + teacher_indices: torch.Tensor + behavior_logprobs: torch.Tensor + + class DistillationTrainer: """Runs one SDPO distillation update using a loaded base model.""" @@ -219,6 +230,24 @@ def _build_self_teacher_topk( torch.cuda.empty_cache() return top_logprobs, top_indices, teacher_scored_text + def _compute_student_response_logprobs( + self, + model: "PeftModel | PeftMixedModel", + full_ids: torch.Tensor, + response_ids: torch.Tensor, + response_start: int, + ) -> torch.Tensor: + """Compute student logprobs for response tokens under the current adapter.""" + with torch.no_grad(): + student_output = model(input_ids=full_ids) + student_logits = student_output.logits[:, response_start - 1 : -1, :] + student_logprobs = self.functional.log_softmax(student_logits, dim=-1).gather( + -1, response_ids.unsqueeze(-1) + ).squeeze(-1) + del student_output, student_logits + torch.cuda.empty_cache() + return student_logprobs.to(dtype=torch.float32).detach() + def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: """Run one SDPO distillation step. @@ -257,13 +286,10 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: ) self._load_optimizer_state(lora_local_path, optimizer) - batch_loss_tensors: list[torch.Tensor] = [] - batch_distill_loss: list[float] = [] - batch_kl_reg: list[float] = [] - batch_mean_is_ratio: list[float] = [] - batch_clip_fraction: list[float] = [] + prepared_samples: list[PreparedSample] = [] batch_teacher_scored_texts: list[str] = [] - tokens_processed = 0 + tokens_per_step = 0 + repeated_feedback_count = config.feedback_repetitions for sample in payload.samples: prompt_ids = torch.tensor( @@ -276,79 +302,124 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: device=self.device, dtype=torch.int64, ) + response_token_count = response_ids.shape[-1] + if len(sample.response_logprobs) != response_token_count: + raise ValueError("response_logprobs length must match response token length") full_ids = torch.cat([prompt_ids, response_ids], dim=-1) response_start = prompt_ids.shape[-1] - response_token_count = response_ids.shape[-1] - tokens_processed += int(response_token_count) - - response_mask = torch.zeros(1, full_ids.shape[-1], device=self.device) - response_mask[:, response_start:] = 1.0 + tokens_per_step += int(response_token_count) with torch.no_grad(), model.disable_adapter(): base_output = self.base_model(input_ids=full_ids) base_logits = base_output.logits[:, response_start - 1 : -1, :] base_logprobs = self.functional.log_softmax(base_logits, dim=-1).gather( - -1, response_ids[:, :response_token_count].unsqueeze(-1) + -1, response_ids.unsqueeze(-1) ).squeeze(-1) - del base_output, base_logits torch.cuda.empty_cache() - student_output = model(input_ids=full_ids) - student_logits = student_output.logits[:, response_start - 1 : -1, :].contiguous() - del student_output - - old_student_logprobs = torch.tensor( - sample.response_logprobs, - dtype=torch.float32, - device=self.device, - ).unsqueeze(0) - if old_student_logprobs.shape[1] > response_token_count: - old_student_logprobs = old_student_logprobs[:, :response_token_count] - elif old_student_logprobs.shape[1] < response_token_count: - raise ValueError("response_logprobs length must match response token length") - + repeated_feedback = " ".join([sample.feedback] * repeated_feedback_count) teacher_logprobs, teacher_indices, teacher_scored_text = self._build_self_teacher_topk( sample.user_prompt, - sample.feedback, + repeated_feedback, response_ids, config.teacher_top_k, system_prompt=sample.system_prompt, peft_model=model, ) - if teacher_logprobs.shape[0] != response_token_count: raise ValueError("teacher logprob sequence length must match response length") - loss_input = SDPOLossInput( - student_logits=student_logits, - teacher_logprobs=teacher_logprobs.unsqueeze(0), - teacher_indices=teacher_indices.unsqueeze(0), - base_logprobs=base_logprobs, - response_mask=response_mask[:, response_start:], - old_student_logprobs=old_student_logprobs, - response_ids=response_ids[:, :response_token_count], - training=config, + behavior_logprobs = torch.tensor( + sample.response_logprobs, + dtype=torch.float32, + device=self.device, + ).unsqueeze(0) + + prepared_samples.append( + PreparedSample( + full_ids=full_ids, + response_ids=response_ids, + response_start=int(response_start), + response_mask=torch.ones( + (1, response_token_count), + device=self.device, + dtype=torch.float32, + ), + base_logprobs=base_logprobs, + teacher_logprobs=teacher_logprobs.unsqueeze(0), + teacher_indices=teacher_indices.unsqueeze(0), + behavior_logprobs=behavior_logprobs, + ) ) - loss_dict = compute_sdpo_loss(loss_input) - batch_loss_tensors.append(loss_dict["loss"]) - batch_distill_loss.append(loss_dict["distill_loss"]) - batch_kl_reg.append(loss_dict["kl_reg"]) - batch_mean_is_ratio.append(loss_dict["mean_is_ratio"]) - batch_clip_fraction.append(loss_dict["clip_fraction"]) batch_teacher_scored_texts.append(teacher_scored_text) - del full_ids, prompt_ids, response_ids, response_mask - del student_logits, base_logprobs, old_student_logprobs - del teacher_logprobs, teacher_indices, loss_input - - mean_loss = torch.stack(batch_loss_tensors).mean() - mean_loss.backward() + del prompt_ids + + step_metrics: list[dict[str, float | int]] = [] + + for step_idx in range(config.steps_per_batch): + batch_loss_tensors: list[torch.Tensor] = [] + batch_distill_loss: list[float] = [] + batch_kl_reg: list[float] = [] + batch_mean_is_ratio: list[float] = [] + batch_clip_fraction: list[float] = [] + + for sample_state in prepared_samples: + student_output = model(input_ids=sample_state["full_ids"]) + student_logits = student_output.logits[ + :, sample_state["response_start"] - 1 : -1, : + ].contiguous() + del student_output + + loss_input = SDPOLossInput( + student_logits=student_logits, + teacher_logprobs=sample_state["teacher_logprobs"], + teacher_indices=sample_state["teacher_indices"], + base_logprobs=sample_state["base_logprobs"], + response_mask=sample_state["response_mask"], + old_student_logprobs=sample_state["behavior_logprobs"], + response_ids=sample_state["response_ids"], + training=config, + ) + loss_dict = compute_sdpo_loss(loss_input) + batch_loss_tensors.append(loss_dict["loss"]) + batch_distill_loss.append(loss_dict["distill_loss"]) + batch_kl_reg.append(loss_dict["kl_reg"]) + batch_mean_is_ratio.append(loss_dict["mean_is_ratio"]) + batch_clip_fraction.append(loss_dict["clip_fraction"]) + + del student_logits, loss_input + + mean_loss = torch.stack(batch_loss_tensors).mean() + mean_loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_(lora_params, config.max_grad_norm) + optimizer.step() + optimizer.zero_grad() + + grad_norm_value = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm) + step_metrics.append( + { + "step": step_idx + 1, + "total_loss": mean_loss.item(), + "distill_loss": sum(batch_distill_loss) / len(batch_distill_loss), + "kl_reg": sum(batch_kl_reg) / len(batch_kl_reg), + "mean_is_ratio": sum(batch_mean_is_ratio) / len(batch_mean_is_ratio), + "clip_fraction": sum(batch_clip_fraction) / len(batch_clip_fraction), + "grad_norm": grad_norm_value, + } + ) - grad_norm = torch.nn.utils.clip_grad_norm_(lora_params, config.max_grad_norm) - optimizer.step() - optimizer.zero_grad() + if step_idx < config.steps_per_batch - 1: + for sample_state in prepared_samples: + sample_state["behavior_logprobs"] = self._compute_student_response_logprobs( + model=model, + full_ids=sample_state["full_ids"], + response_ids=sample_state["response_ids"], + response_start=sample_state["response_start"], + ) model.gradient_checkpointing_disable() @@ -363,28 +434,25 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: finally: cleanup_local_lora(save_dir) - total_loss = mean_loss.item() - distill_loss = sum(batch_distill_loss) / len(batch_distill_loss) - kl_reg = sum(batch_kl_reg) / len(batch_kl_reg) - mean_is_ratio = sum(batch_mean_is_ratio) / len(batch_mean_is_ratio) - clip_fraction = sum(batch_clip_fraction) / len(batch_clip_fraction) - grad_norm_value = grad_norm.item() if hasattr(grad_norm, "item") else float(grad_norm) + final_step_metrics = step_metrics[-1] - del model, optimizer, batch_loss_tensors + del model, optimizer torch.cuda.empty_cache() return DistillResponse.model_validate( { "lora_id": new_lora_id, "metadata": { - "total_loss": total_loss, - "distill_loss": distill_loss, - "kl_reg": kl_reg, - "mean_is_ratio": mean_is_ratio, - "clip_fraction": clip_fraction, - "grad_norm": grad_norm_value, - "tokens_processed": tokens_processed, + "total_loss": final_step_metrics["total_loss"], + "distill_loss": final_step_metrics["distill_loss"], + "kl_reg": final_step_metrics["kl_reg"], + "mean_is_ratio": final_step_metrics["mean_is_ratio"], + "clip_fraction": final_step_metrics["clip_fraction"], + "grad_norm": final_step_metrics["grad_norm"], + "tokens_processed": tokens_per_step * config.steps_per_batch, "batch_size": len(payload.samples), + "steps_per_batch_applied": config.steps_per_batch, + "per_step_metrics": step_metrics, "teacher_scored_texts": batch_teacher_scored_texts, }, } diff --git a/claas/training/engine/tinker/engine.py b/claas/training/engine/tinker/engine.py index c5db847..e3a0dc6 100644 --- a/claas/training/engine/tinker/engine.py +++ b/claas/training/engine/tinker/engine.py @@ -14,7 +14,7 @@ import logging import os from datetime import datetime, timezone -from typing import Any +from typing import Any, TypedDict from urllib.parse import quote import httpx @@ -161,27 +161,32 @@ async def distill( ) -> DistillResponse: """Run one SDPO distillation step entirely through the Tinker SDK. - Supports batched samples: all samples are processed concurrently, - then a single forward_backward + optim_step is applied. + Supports batched samples with configurable multi-step updates. + Sample preparation (teacher signals) runs once, then each optimizer + step rebuilds importance-weighted datums using current behavior logprobs. Per-sample flow: 1. Tokenize prompt + response directly (no chat wrapping) - 2. Compute student (rollout) logprobs via SamplingClient + 2. Read initial student (rollout) logprobs from cached response 3. Build teacher prompt using build_teacher_messages (feedback reprompt) 4. Compute teacher logprobs (base model conditioned on feedback) 5. Derive advantages with adaptive KL scaling 6. Build Datum with right-shifted alignment Batch flow: - 7. forward_backward with importance_sampling loss (all datums) - 8. optim_step with AdamW - 9. Save checkpoint and update state + 7. Repeat `steps_per_batch` times: + - forward_backward with importance_sampling loss + - optim_step with AdamW + - recompute behavior logprobs for next step (if needed) + 8. Save checkpoint and update state """ entry = _require_entry(payload.lora_id, self._state_path) base_model = entry.base_model lr = payload.training.learning_rate kl_coef = payload.training.alpha + steps_per_batch = payload.training.steps_per_batch + feedback_repetitions = payload.training.feedback_repetitions # ── Phase 1: Setup (once per batch) ── training_client = await self.service.create_training_client_from_state_async( @@ -195,25 +200,69 @@ async def distill( # ── Phase 2: Per-sample processing (concurrent) ── tasks = [ - _build_sample_datum( - sample, tokenizer, teacher_sampling, kl_coef + _prepare_sample_inputs( + sample=sample, + tokenizer=tokenizer, + teacher_sampling=teacher_sampling, + feedback_repetitions=feedback_repetitions, ) for sample in payload.samples ] results = await asyncio.gather(*tasks) - datums = [r[0] for r in results] - sample_metrics = [r[1] for r in results] + prepared_samples = [r[0] for r in results] + behavior_logprobs = [r[1] for r in results] + + # ── Phase 3: Multi-step training ── + step_metrics: list[dict[str, float | int]] = [] + final_fwd_metrics: dict[str, object] | None = None + for step_idx in range(steps_per_batch): + datum_metrics = [ + _build_sample_datum( + prepared=prepared, + student_logprobs=behavior_logprobs[sample_idx], + kl_coef=kl_coef, + ) + for sample_idx, prepared in enumerate(prepared_samples) + ] + datums = [dm[0] for dm in datum_metrics] + sample_metrics = [dm[1] for dm in datum_metrics] + + fwd_bwd = await training_client.forward_backward_async( + datums, "importance_sampling" + ) + await training_client.optim_step_async( + T.AdamParams(learning_rate=lr, beta1=0.9, beta2=0.95) + ) + if hasattr(fwd_bwd, "metrics") and fwd_bwd.metrics: + final_fwd_metrics = dict(fwd_bwd.metrics) + + n = len(sample_metrics) + total_completion_len = sum(m["completion_len"] for m in sample_metrics) + avg = lambda key: sum(m[key] for m in sample_metrics) / n # noqa: E731 + step_metrics.append( + { + "step": step_idx + 1, + "batch_size": n, + "completion_len": total_completion_len, + "effective_kl_coef": avg("effective_kl_coef"), + "kl_gain": avg("kl_gain"), + "adv_mean": avg("adv_mean"), + "adv_abs_mean": avg("adv_abs_mean"), + "kl_mean": avg("kl_mean"), + "adv_abs_mean_raw": avg("adv_abs_mean_raw"), + } + ) - # ── Phase 3: Training step (once per batch) ── - fwd_bwd = await training_client.forward_backward_async( - datums, "importance_sampling" - ) - await training_client.optim_step_async( - T.AdamParams(learning_rate=lr, beta1=0.9, beta2=0.95) - ) + if step_idx < steps_per_batch - 1: + student_sampling = await training_client.save_weights_and_get_sampling_client_async() + behavior_logprobs = await _compute_student_logprobs_for_batch( + student_sampling=student_sampling, + prepared_samples=prepared_samples, + ) - # ── Phase 4: Save & return (once per batch) ── - new_step = entry.step + 1 + # ── Phase 4: Save & return (once per request) ── + final_step = step_metrics[-1] + new_step = entry.step + steps_per_batch checkpoint_name = f"step-{new_step}" save_result = await _await_api_future(await training_client.save_state_async(checkpoint_name)) @@ -232,65 +281,69 @@ async def distill( path=self._state_path, ) - # Aggregate metrics across samples. - n = len(sample_metrics) - total_completion_len = sum(m["completion_len"] for m in sample_metrics) - avg = lambda key: sum(m[key] for m in sample_metrics) / n # noqa: E731 - metadata: dict[str, object] = { "step": new_step, "tinker_path": save_result.path, "sampler_weights_path": sampler_weights_path, - "batch_size": n, - "completion_len": total_completion_len, - "effective_kl_coef": avg("effective_kl_coef"), - "kl_gain": avg("kl_gain"), - "adv_mean": avg("adv_mean"), - "adv_abs_mean": avg("adv_abs_mean"), - "kl_mean": avg("kl_mean"), - "adv_abs_mean_raw": avg("adv_abs_mean_raw"), + "batch_size": final_step["batch_size"], + "completion_len": final_step["completion_len"], + "effective_kl_coef": final_step["effective_kl_coef"], + "kl_gain": final_step["kl_gain"], + "adv_mean": final_step["adv_mean"], + "adv_abs_mean": final_step["adv_abs_mean"], + "kl_mean": final_step["kl_mean"], + "adv_abs_mean_raw": final_step["adv_abs_mean_raw"], "lr": lr, "loss_fn": "importance_sampling", "timestamp": datetime.now(timezone.utc).isoformat(), - "teacher_scored_texts": [m["teacher_scored_text"] for m in sample_metrics], + "teacher_scored_texts": [p["teacher_scored_text"] for p in prepared_samples], + "steps_per_batch_applied": steps_per_batch, + "per_step_metrics": step_metrics, } - if hasattr(fwd_bwd, "metrics") and fwd_bwd.metrics: - metadata["tinker_fwd_metrics"] = fwd_bwd.metrics + if final_fwd_metrics is not None: + metadata["tinker_fwd_metrics"] = final_fwd_metrics return DistillResponse(lora_id=payload.lora_id, metadata=metadata) -async def _build_sample_datum( +class PreparedSample(TypedDict): + full_tokens: list[int] + input_tokens: list[int] + target_tokens: list[int] + prompt_len: int + completion_len: int + teacher_logprobs: list[float] + teacher_scored_text: str + + +async def _prepare_sample_inputs( + *, sample: DistillBatchItem, tokenizer: Any, teacher_sampling: Any, - kl_coef: float, -) -> tuple[T.Datum, dict[str, float | str]]: - """Process a single sample into a Datum and per-sample metrics. - - This is the per-sample logic extracted from ``distill()`` so that - multiple samples can be processed concurrently via ``asyncio.gather``. - """ - # ── Tokenize prompt + response directly (matching local worker) ── + feedback_repetitions: int, +) -> tuple[PreparedSample, list[float]]: + """Prepare sample-invariant tensors and initial behavior logprobs.""" prompt_tokens = list(sample.prompt_token_ids) response_tokens = list(sample.response_token_ids) - full_tokens = prompt_tokens + response_tokens - prompt_len = len(prompt_tokens) completion_len = len(response_tokens) - - # ── Validate and use provided response logprobs ── if len(sample.response_logprobs) != completion_len: raise ValueError( f"response_logprobs length ({len(sample.response_logprobs)}) != " f"completion_len ({completion_len})" ) - student_logprobs = list(sample.response_logprobs) - # ── Build teacher prompt (matching local worker: build_teacher_messages) ── - teacher_prompt_source = sample.user_prompt + full_tokens = prompt_tokens + response_tokens + prompt_len = len(prompt_tokens) + input_tokens = full_tokens[:-1] + target_tokens = full_tokens[1:] + + repeated_feedback = " ".join([sample.feedback] * feedback_repetitions) teacher_messages = build_teacher_messages( - teacher_prompt_source, sample.feedback, system_prompt=sample.system_prompt + sample.user_prompt, + repeated_feedback, + system_prompt=sample.system_prompt, ) template_messages = teacher_messages_to_chat_template(teacher_messages) teacher_prompt_text = tokenizer.apply_chat_template( @@ -307,13 +360,40 @@ async def _build_sample_datum( teacher_full = T.ModelInput.from_ints(teacher_full_tokens) teacher_scored_text = tokenizer.decode(teacher_full_tokens, skip_special_tokens=False) - # ── Compute teacher logprobs (base model = self-distillation) ── teacher_logprobs_full = await teacher_sampling.compute_logprobs_async(teacher_full) teacher_logprobs = _slice_completion_logprobs( - teacher_logprobs_full, teacher_prompt_len, completion_len + teacher_logprobs_full, + teacher_prompt_len, + completion_len, + ) + + prepared = PreparedSample( + full_tokens=full_tokens, + input_tokens=input_tokens, + target_tokens=target_tokens, + prompt_len=prompt_len, + completion_len=completion_len, + teacher_logprobs=teacher_logprobs, + teacher_scored_text=teacher_scored_text, ) + return prepared, list(sample.response_logprobs) - # ── Compute advantages with adaptive KL scaling ── + +def _build_sample_datum( + *, + prepared: PreparedSample, + student_logprobs: list[float], + kl_coef: float, +) -> tuple[T.Datum, dict[str, float]]: + """Build a Tinker datum from prepared teacher signals + current behavior policy.""" + completion_len = prepared["completion_len"] + if len(student_logprobs) != completion_len: + raise ValueError( + f"student_logprobs length ({len(student_logprobs)}) != " + f"completion_len ({completion_len})" + ) + + teacher_logprobs = prepared["teacher_logprobs"] raw_kl_deltas = [t - s for s, t in zip(student_logprobs, teacher_logprobs, strict=True)] adv_abs_mean_raw = sum(abs(d) for d in raw_kl_deltas) / max(len(raw_kl_deltas), 1) @@ -321,49 +401,68 @@ async def _build_sample_datum( if adv_abs_mean_raw > 0: gain = min(max(_TARGET_ADV_ABS_MEAN / adv_abs_mean_raw, 1.0), _MAX_KL_GAIN) effective_kl_coef = kl_coef * gain - advantages = [ effective_kl_coef * (t - s) for s, t in zip(student_logprobs, teacher_logprobs, strict=True) ] - # ── Build Datum with right-shifted alignment ── - input_tokens = full_tokens[:-1] - target_tokens = full_tokens[1:] - input_model_input = T.ModelInput.from_ints(input_tokens) - - full_logprobs = [0.0] * prompt_len + student_logprobs - full_advantages = [0.0] * prompt_len + advantages + full_logprobs = [0.0] * prepared["prompt_len"] + student_logprobs + full_advantages = [0.0] * prepared["prompt_len"] + advantages shifted_logprobs = full_logprobs[1:] shifted_advantages = full_advantages[1:] datum = T.Datum( - model_input=input_model_input, + model_input=T.ModelInput.from_ints(prepared["input_tokens"]), loss_fn_inputs={ - "target_tokens": TensorData(data=target_tokens, dtype="int64"), + "target_tokens": TensorData(data=prepared["target_tokens"], dtype="int64"), "logprobs": TensorData(data=shifted_logprobs, dtype="float32"), "advantages": TensorData(data=shifted_advantages, dtype="float32"), }, ) - adv_mean = sum(advantages) / max(len(advantages), 1) - adv_abs_mean = sum(abs(a) for a in advantages) / max(len(advantages), 1) - kl_mean = sum(raw_kl_deltas) / max(len(raw_kl_deltas), 1) - - metrics: dict[str, float | str] = { + metrics = { "completion_len": completion_len, "effective_kl_coef": effective_kl_coef, "kl_gain": gain, - "adv_mean": adv_mean, - "adv_abs_mean": adv_abs_mean, - "kl_mean": kl_mean, + "adv_mean": sum(advantages) / max(len(advantages), 1), + "adv_abs_mean": sum(abs(a) for a in advantages) / max(len(advantages), 1), + "kl_mean": sum(raw_kl_deltas) / max(len(raw_kl_deltas), 1), "adv_abs_mean_raw": adv_abs_mean_raw, - "teacher_scored_text": teacher_scored_text, } - return datum, metrics +async def _compute_student_logprobs_for_batch( + *, + student_sampling: Any, + prepared_samples: list[PreparedSample], +) -> list[list[float]]: + """Recompute behavior logprobs under the updated student policy.""" + tasks = [ + _compute_student_logprobs_for_sample( + student_sampling=student_sampling, + prepared=prepared, + ) + for prepared in prepared_samples + ] + return await asyncio.gather(*tasks) + + +async def _compute_student_logprobs_for_sample( + *, + student_sampling: Any, + prepared: PreparedSample, +) -> list[float]: + """Compute completion logprobs for one sample under current student weights.""" + student_full = T.ModelInput.from_ints(prepared["full_tokens"]) + student_logprobs_full = await student_sampling.compute_logprobs_async(student_full) + return _slice_completion_logprobs( + student_logprobs_full, + prepared["prompt_len"], + prepared["completion_len"], + ) + + def _require_entry(lora_id: str, state_path: str) -> LoraEntry: """Return the state entry for *lora_id* or raise ``FileNotFoundError``.""" entry = get_entry(lora_id, path=state_path) diff --git a/tests/test_eval_config.py b/tests/test_eval_config.py index 930ce8e..9bea834 100644 --- a/tests/test_eval_config.py +++ b/tests/test_eval_config.py @@ -97,9 +97,14 @@ def test_collapse_steps_list() -> None: assert config.collapse_steps == [0, 5, 10] -def test_steps_per_batch_from_config() -> None: - config = _compose_eval_config(["steps_per_batch=3"]) - assert config.steps_per_batch == 3 +def test_training_steps_per_batch_from_config() -> None: + config = _compose_eval_config(["training.steps_per_batch=3"]) + assert config.training.steps_per_batch == 3 + + +def test_eval_top_level_steps_per_batch_rejected() -> None: + with pytest.raises(ConfigCompositionException): + _compose_eval_config(["steps_per_batch=3"]) def test_hydra_config_custom_dir() -> None: diff --git a/tests/test_eval_runner.py b/tests/test_eval_runner.py index 38cfebf..0093504 100644 --- a/tests/test_eval_runner.py +++ b/tests/test_eval_runner.py @@ -34,6 +34,7 @@ def test_tinker_distill_metrics_fields(): assert m.adv_abs_mean_raw == 0.60 assert m.completion_len == 128 assert m.batch_size == 4 + assert m.steps_per_batch_applied == 1 def test_tinker_distill_metrics_defaults(): @@ -48,6 +49,7 @@ def test_tinker_distill_metrics_defaults(): ) assert m.completion_len == 0 assert m.batch_size == 0 + assert m.steps_per_batch_applied == 1 # ── step_result_from_dict deserialization ───────────────────────────── @@ -76,11 +78,13 @@ def test_step_result_from_dict_tinker_metrics(): "adv_abs_mean_raw": 0.55, "completion_len": 64, "batch_size": 4, + "steps_per_batch_applied": 3, }, } result = step_result_from_dict(data) assert isinstance(result.sdpo_metrics, TinkerDistillMetrics) assert result.sdpo_metrics.adv_mean == -0.5 + assert result.sdpo_metrics.steps_per_batch_applied == 3 def test_step_result_from_dict_local_metrics(): @@ -92,11 +96,13 @@ def test_step_result_from_dict_local_metrics(): "kl_reg": 0.02, "mean_is_ratio": 1.01, "clip_fraction": 0.05, + "steps_per_batch_applied": 2, }, } result = step_result_from_dict(data) assert isinstance(result.sdpo_metrics, LocalDistillMetrics) assert result.sdpo_metrics.distill_loss == 0.35 + assert result.sdpo_metrics.steps_per_batch_applied == 2 def test_step_result_from_dict_no_metrics(): @@ -108,28 +114,19 @@ def test_step_result_from_dict_no_metrics(): # ── EvalConfig defaults ────────────────────────────────────────────── -def test_steps_per_batch_default(): - """EvalConfig.steps_per_batch defaults to 4.""" +def test_training_config_default(): + """EvalConfig.training defaults to TrainingConfig defaults.""" config = EvalConfig() - assert config.steps_per_batch == 4 + assert config.training == TrainingConfig() -def test_steps_per_batch_custom(): - """EvalConfig.steps_per_batch can be set.""" - config = EvalConfig(steps_per_batch=3) - assert config.steps_per_batch == 3 - - -def test_feedback_repetitions_default(): - """EvalConfig.feedback_repetitions defaults to 1.""" - config = EvalConfig() - assert config.feedback_repetitions == 1 - - -def test_feedback_repetitions_custom(): - """EvalConfig.feedback_repetitions can be set.""" - config = EvalConfig(feedback_repetitions=4) - assert config.feedback_repetitions == 4 +def test_training_config_custom(): + """EvalConfig.training accepts custom steps/repetitions.""" + config = EvalConfig( + training=TrainingConfig(steps_per_batch=3, feedback_repetitions=4), + ) + assert config.training.steps_per_batch == 3 + assert config.training.feedback_repetitions == 4 def test_training_config_default_type(): diff --git a/tests/test_tinker_engine.py b/tests/test_tinker_engine.py index 7dd35ca..a7ade33 100644 --- a/tests/test_tinker_engine.py +++ b/tests/test_tinker_engine.py @@ -511,6 +511,68 @@ def test_engine_distill_uses_provided_response_logprobs(tinker_engine, mock_trai mock_training_client.save_weights_and_get_sampling_client_async.assert_not_called() +def test_engine_distill_multistep_recomputes_behavior_logprobs(tinker_engine, mock_training_client): + """Multi-step distill recomputes behavior logprobs between optimizer steps.""" + engine, mock_service = tinker_engine + + set_tinker_path("test/lora", "tinker://ckpt", "gpt-oss/GPT-OSS-120B", 32, step=0) + mock_service.create_training_client_from_state_async = AsyncMock( + return_value=mock_training_client + ) + + teacher_sampler = MagicMock() + teacher_sampler.compute_logprobs_async = AsyncMock(return_value=[-0.3] * 100) + mock_service.create_sampling_client_async = AsyncMock(return_value=teacher_sampler) + + student_sampler = MagicMock() + student_sampler.compute_logprobs_async = AsyncMock(return_value=[-0.15] * 100) + mock_training_client.save_weights_and_get_sampling_client_async = AsyncMock( + return_value=student_sampler + ) + + payload = DistillBatchRequestPayload( + lora_id="test/lora", + training=TrainingConfig(steps_per_batch=2), + samples=[ + DistillBatchItem( + prompt="Hello", + response="World", + feedback="Nice", + response_logprobs=[-0.1, -0.2, -0.3, -0.4, -0.5], + prompt_token_ids=[0, 1, 2, 3, 4], + response_token_ids=[0, 1, 2, 3, 4], + user_prompt="Hello", + system_prompt="You are a helpful assistant.", + ) + ], + ) + + result = asyncio.run(engine.distill(payload)) + + assert isinstance(result, DistillResponse) + assert result.metadata["steps_per_batch_applied"] == 2 + assert result.metadata["step"] == 2 + + # Two optimizer updates, one recompute between them. + assert mock_training_client.forward_backward_async.call_count == 2 + assert mock_training_client.optim_step_async.call_count == 2 + mock_training_client.save_weights_and_get_sampling_client_async.assert_called_once() + student_sampler.compute_logprobs_async.assert_called() + + # Final checkpoint reflects cumulative training steps. + mock_training_client.save_state_async.assert_called_once_with("step-2") + entry = get_entry("test/lora") + assert entry is not None + assert entry.step == 2 + + # Step-2 datum should use recomputed student logprobs (different from rollout logprobs). + first_call = mock_training_client.forward_backward_async.call_args_list[0] + second_call = mock_training_client.forward_backward_async.call_args_list[1] + first_logprobs = first_call.args[0][0].loss_fn_inputs["logprobs"].data + second_logprobs = second_call.args[0][0].loss_fn_inputs["logprobs"].data + assert first_logprobs != second_logprobs + + def test_engine_distill_batch_multiple_samples(tinker_engine, mock_training_client): """Batched distill: 3 samples processed concurrently, single train step.