Fix PEFT base-model contamination and tune SDPO defaults #37
Conversation
Two bugs fixed:
- PEFT modifies the base model in-place when LoRA is loaded, so teacher scoring and KL `base_logprobs` were computed WITH the LoRA active (making the KL term always 0). Use `model.disable_adapter()` to get clean base-model outputs.
- ConciseVerifier accepted degenerate responses (e.g. ".") because 0 sentences <= 3. Add a `MIN_WORDS=10` guard.

Default hyperparameters updated from experimental sweep (Exp 1-9):
- learning_rate: 5e-5 -> 3e-5
- kl_reg_weight: 0.001 -> 0.0 (teacher GJS provides sufficient anchoring)
- steps_per_batch: 1 -> 4

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
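The adapter-contamination bug can be illustrated with a toy stand-in for PEFT (`ToyLoraModel` below is hypothetical; only the `disable_adapter()` context-manager pattern mirrors the real `peft.PeftModel` API):

```python
from contextlib import contextmanager

class ToyLoraModel:
    """Toy model: base output plus a LoRA delta applied in-place."""
    def __init__(self):
        self.base_weight = 1.0
        self.lora_delta = 0.5      # present once LoRA is loaded
        self.adapter_enabled = True

    def forward(self, x: float) -> float:
        w = self.base_weight + (self.lora_delta if self.adapter_enabled else 0.0)
        return w * x

    @contextmanager
    def disable_adapter(self):
        # Mirrors PeftModel.disable_adapter(): temporarily bypass LoRA
        self.adapter_enabled = False
        try:
            yield
        finally:
            self.adapter_enabled = True

model = ToyLoraModel()
policy_out = model.forward(2.0)   # LoRA active
buggy_base = model.forward(2.0)   # bug: LoRA still active -> identical to policy
with model.disable_adapter():
    clean_base = model.forward(2.0)  # adapter off -> true base output
```

With the adapter left enabled, the "base" output equals the policy output, so the KL term between them is exactly 0, which is why the bug silently disabled regularization.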
📝 Walkthrough

The changes update hyperparameter defaults: learning rate from 5e-5 to 3e-5, KL regularization weight from 0.001 to 0.0, and steps_per_batch from 1 to 4 across configuration files. ConciseVerifier enforces a minimum 10-word substantive length constraint. Distillation training now optionally disables PEFT adapters during teacher logit construction and base model inference.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~15 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
`SDPOLossInput` had `alpha`, `is_clip`, and `kl_reg_weight` fields that duplicated `TrainingConfig` defaults. Replace with a single `training` field that accepts the config directly, eliminating the duplication.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
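A minimal sketch of the refactor; the field defaults below are illustrative, and the real `TrainingConfig` presumably carries more fields:

```python
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    alpha: float = 0.5          # illustrative default
    is_clip: float = 2.0        # illustrative default
    kl_reg_weight: float = 0.0

# Before: loss input duplicated the config defaults field by field
@dataclass
class SDPOLossInputOld:
    alpha: float = 0.5
    is_clip: float = 2.0
    kl_reg_weight: float = 0.0

# After: a single `training` field carries the config directly,
# so defaults live in exactly one place
@dataclass
class SDPOLossInput:
    training: TrainingConfig

loss_input = SDPOLossInput(training=TrainingConfig(kl_reg_weight=0.0))
```

Callers now read e.g. `loss_input.training.kl_reg_weight` instead of a duplicated field, so a change to a `TrainingConfig` default cannot silently diverge from the loss input.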
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
claas/training/distillation.py (1)
165-208: ⚠️ Potential issue | 🟡 Minor

Correct `_load_or_create_lora` return type annotation.

The function is annotated as returning `PeftModel | PeftMixedModel` (line 97), but both code paths only return `PeftModel`: `PeftModel.from_pretrained()` in the if-branch and `get_peft_model()` in the else-branch. This overstated return type propagates to the `peft_model` parameter in `_build_self_teacher_topk` (line 173), where it is annotated as `PeftModel | PeftMixedModel | None`.

Since `PeftMixedModel` is never actually instantiated in the codebase and the function always returns `PeftModel`, the return type should be just `PeftModel`. This would eliminate the false contract and accurately reflect what the code actually does.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/distillation.py` around lines 165 - 208, The declared return type of _load_or_create_lora should be narrowed to "PeftModel" because both branches return PeftModel (via PeftModel.from_pretrained and get_peft_model); update the function signature accordingly and then update the peft_model parameter annotation on _build_self_teacher_topk from "PeftModel | PeftMixedModel | None" to "PeftModel | None" so the type hints reflect actual returned/accepted types (ensure you reference _load_or_create_lora, PeftModel.from_pretrained, get_peft_model, and _build_self_teacher_topk when making the edits).
🧹 Nitpick comments (2)
claas/training/distillation.py (2)
288-293: `base_logprobs` forward pass runs even when `kl_reg_weight == 0.0` (new default).

With the default `kl_reg_weight=0.0`, the KL term is zeroed out in `compute_sdpo_loss`, so `base_logprobs` contribute nothing to the loss. Yet the full base-model forward pass still executes on every sample, compounded by `steps_per_batch=4`. Consider guarding it:

⚡ Optional: skip base forward pass when KL weight is zero
```diff
-        with torch.no_grad(), model.disable_adapter():
-            base_output = self.base_model(input_ids=full_ids)
-            base_logits = base_output.logits[:, response_start - 1 : -1, :]
-            base_logprobs = self.functional.log_softmax(base_logits, dim=-1).gather(
-                -1, response_ids[:, :response_token_count].unsqueeze(-1)
-            ).squeeze(-1)
-
-        del base_output, base_logits
-        torch.cuda.empty_cache()
+        if config.kl_reg_weight > 0.0:
+            with torch.no_grad(), model.disable_adapter():
+                base_output = self.base_model(input_ids=full_ids)
+                base_logits = base_output.logits[:, response_start - 1 : -1, :]
+                base_logprobs = self.functional.log_softmax(base_logits, dim=-1).gather(
+                    -1, response_ids[:, :response_token_count].unsqueeze(-1)
+                ).squeeze(-1)
+            del base_output, base_logits
+            torch.cuda.empty_cache()
+        else:
+            base_logprobs = torch.zeros(1, response_token_count, device=self.device)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/distillation.py` around lines 288 - 293, The base-model forward and computation of base_logprobs are executed even when kl_reg_weight is 0.0; guard that work by checking the KL weight before running the base forward pass. Modify the block that calls base_model and computes base_logits/base_logprobs (symbols: base_model, base_output, base_logits, base_logprobs, model.disable_adapter) so it only runs when self.kl_reg_weight (or the local kl_reg_weight used by compute_sdpo_loss) > 0.0; otherwise set base_logprobs to an appropriate zero/tensor placeholder so downstream compute_sdpo_loss can proceed without the extra forward. Ensure the conditional preserves shapes and device dtype expectations used later (e.g., when gathering or computing losses) and does not change behavior when kl_reg_weight > 0.0.
176-189: Docstring is missing the new `peft_model` parameter (and the pre-existing `system_prompt`).

📝 Suggested docstring addition
```diff
     Args:
         prompt: User prompt.
         feedback: Critique text.
         response_ids: Tokenized sampled response.
         top_k: Number of logits to retain per token.
+        system_prompt: System prompt passed to teacher prompt construction.
+        peft_model: When provided, the teacher forward pass runs under
+            ``peft_model.disable_adapter()`` so LoRA weights are excluded
+            from the base-model scoring.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/distillation.py` around lines 176 - 189, The docstring in claas/training/distillation.py that currently lists Args (prompt, feedback, response_ids, top_k) is missing the new peft_model parameter and the existing system_prompt; update that function's docstring to add entries for system_prompt (type: str, short description: system-level prompt prepended to user prompt) and peft_model (type: any/model object, short description: the PEFT/tuned model used for scoring or inference), include expected types and how each affects outputs, and keep the Returns and Reference sections unchanged; place these two new param lines under the Args section alongside the other parameters so documentation and signatures match.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to data retention organization setting
📒 Files selected for processing (7)
- claas/core/types.py
- claas/eval/configs/base.yaml
- claas/eval/metrics/verifiers.py
- claas/eval/types.py
- claas/training/distillation.py
- tests/test_eval_config.py
- tests/test_eval_runner.py
Summary
- Fixed `disable_adapter()` bug: PEFT modifies the base model in-place when LoRA is loaded, so teacher scoring and KL `base_logprobs` were computed WITH the LoRA active, making KL regularization always 0. Now uses the `model.disable_adapter()` context manager to get clean base-model outputs for both teacher scoring and KL computation.
- Fixed ConciseVerifier: degenerate responses (e.g. ".") passed the ≤3 sentences check because 0 sentences ≤ 3. Added a `MIN_WORDS=10` guard so degenerate outputs score 0.
- Updated default hyperparameters:
  - learning_rate: 5e-5 → 3e-5
  - kl_reg_weight: 0.001 → 0.0 (teacher GJS distillation provides sufficient policy anchoring)
  - steps_per_batch: 1 → 4

Test plan
- `ruff check` passes
- `pytest tests/ -m "not integration"`: 110 passed, 25 skipped

🤖 Generated with Claude Code
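ConciseVerifier's implementation is not shown in the diff; the following is a minimal sketch of the described behavior, with the function name, sentence-splitting rule, and `MAX_SENTENCES` constant assumed for illustration:

```python
import re

MIN_WORDS = 10      # new guard: reject degenerate outputs like "."
MAX_SENTENCES = 3   # assumed pre-existing conciseness limit

def concise_score(response: str) -> float:
    """Score 1.0 for concise, substantive responses; 0.0 otherwise."""
    words = response.split()
    if len(words) < MIN_WORDS:
        # "." has 1 word and 0 sentences; without this guard it would
        # trivially satisfy the sentence-count check below
        return 0.0
    sentences = [s for s in re.split(r"[.!?]+", response) if s.strip()]
    return 1.0 if len(sentences) <= MAX_SENTENCES else 0.0
```

The key point is ordering: the word-count guard fires before the sentence check, so an empty or one-token response can no longer pass via "0 sentences ≤ 3".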