
Fix PEFT base-model contamination and tune SDPO defaults#37

Merged
kfallah merged 2 commits into main from sdpo-bugfix-and-defaults
Feb 25, 2026
Conversation


@kfallah kfallah commented Feb 24, 2026

Summary

  • Fix PEFT disable_adapter() bug: PEFT modifies the base model in-place when LoRA is loaded, so teacher scoring and KL base_logprobs were computed WITH the LoRA active — making KL regularization always 0. Now uses model.disable_adapter() context manager to get clean base-model outputs for both teacher scoring and KL computation.
  • Fix ConciseVerifier degenerate response exploit: Responses under 10 words (e.g. ".") passed the ≤3 sentences check because 0 sentences ≤ 3. Added MIN_WORDS=10 guard so degenerate outputs score 0.
  • Update default hyperparameters from experimental sweep (9 experiments on Qwen3-8B / 5090 GPU):
    • learning_rate: 5e-5 → 3e-5
    • kl_reg_weight: 0.001 → 0.0 (teacher GJS distillation provides sufficient policy anchoring)
    • steps_per_batch: 1 → 4
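The adapter-contamination bug above can be illustrated with a toy stand-in for a PEFT-wrapped model (a minimal sketch: `ToyPeftModel` and its numbers are illustrative, not the project's code; the real fix uses peft's `PeftModel.disable_adapter()` context manager):

```python
from contextlib import contextmanager


class ToyPeftModel:
    """Stand-in for a PEFT model: loading a LoRA mutates the base model in place."""

    def __init__(self) -> None:
        self.adapter_active = True  # LoRA is active as soon as it is loaded

    @contextmanager
    def disable_adapter(self):
        # Mirrors the semantics of peft's disable_adapter() context manager.
        self.adapter_active = False
        try:
            yield
        finally:
            self.adapter_active = True

    def logprob(self) -> float:
        base = -2.0
        # The adapter shifts the model's output; illustrative magnitude only.
        return base + (0.5 if self.adapter_active else 0.0)


model = ToyPeftModel()

# Buggy pattern: "base" logprobs computed with the adapter still active, so
# policy and base logprobs are identical and the KL penalty is always 0.
policy_lp = model.logprob()
contaminated_base_lp = model.logprob()
kl_buggy = policy_lp - contaminated_base_lp  # 0.0

# Fixed pattern: score under disable_adapter() to get clean base-model outputs.
with model.disable_adapter():
    clean_base_lp = model.logprob()
kl_fixed = policy_lp - clean_base_lp  # nonzero

print(kl_buggy, kl_fixed)
```

The same shape applies to teacher scoring: any forward pass meant to reflect the base model must run inside the adapter-disabling context, and the context manager restores the adapter afterwards.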

Test plan

  • ruff check passes
  • pytest tests/ -m "not integration" — 110 passed, 25 skipped
  • Run local eval with updated defaults to verify compliance trajectory

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • Configuration

    • Updated default learning rate from 5e-5 to 3e-5 for model training.
    • Disabled KL regularization by default (changed from 0.001 to 0.0).
    • Increased evaluation batch processing to 4 steps per batch for improved efficiency.
  • Improvements

    • Enhanced validation for concise responses with minimum 10-word requirement.
    • Optimized adapter handling in the distillation training flow.

Two bugs fixed:
- PEFT modifies the base model in-place when LoRA is loaded, so
  teacher scoring and KL base_logprobs were computed WITH the LoRA
  active (making KL always 0). Use model.disable_adapter() to get
  clean base-model outputs.
- ConciseVerifier accepted degenerate responses (e.g. ".") because
  0 sentences <= 3. Add MIN_WORDS=10 guard.
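The verifier guard can be sketched as follows (a minimal illustration, not the project's `ConciseVerifier`; the naive sentence split is an assumption for the demo):

```python
import re

MIN_WORDS = 10      # guard added by the fix
MAX_SENTENCES = 3   # pre-existing conciseness check


def concise_score(response: str) -> float:
    """Toy conciseness check: short enough, but still substantive."""
    words = response.split()
    if len(words) < MIN_WORDS:
        # Degenerate outputs like "." now score 0 instead of slipping
        # through because "0 sentences <= 3" was trivially true.
        return 0.0
    # Illustrative sentence split on ., !, ? terminators.
    sentences = [s for s in re.split(r"[.!?]+", response) if s.strip()]
    return 1.0 if len(sentences) <= MAX_SENTENCES else 0.0


print(concise_score("."))  # degenerate: fails the word-count guard, 0.0
print(concise_score("The cache is stale, so clear it and restart the worker process."))
```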

Default hyperparameters updated from experimental sweep (Exp 1-9):
- learning_rate: 5e-5 -> 3e-5
- kl_reg_weight: 0.001 -> 0.0 (teacher GJS provides sufficient anchoring)
- steps_per_batch: 1 -> 4

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

coderabbitai bot commented Feb 24, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

The changes update hyperparameter defaults: learning rate from 5e-5 to 3e-5, KL regularization weight from 0.001 to 0.0, and steps_per_batch from 1 to 4 across configuration files. ConciseVerifier enforces a minimum 10-word substantive length constraint. Distillation training now optionally disables PEFT adapters during teacher logit construction and base model inference.
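The new defaults described above can be summarized in a sketch (the dataclass shapes are illustrative; the real definitions live in `claas/core/types.py` and the eval config in `claas/eval/configs/base.yaml` / `claas/eval/types.py`):

```python
from dataclasses import dataclass


@dataclass
class TrainingConfig:  # illustrative stand-in for claas/core/types.py
    learning_rate: float = 3e-5   # was 5e-5
    kl_reg_weight: float = 0.0    # was 0.001; teacher GJS distillation anchors the policy


@dataclass
class EvalConfig:  # illustrative stand-in for the eval-side config
    steps_per_batch: int = 4      # was 1


cfg = TrainingConfig()
print(cfg.learning_rate, cfg.kl_reg_weight, EvalConfig().steps_per_batch)
```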

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Hyperparameter Defaults (`claas/core/types.py`) | `TrainingConfig.learning_rate` reduced from 5e-5 to 3e-5 and `kl_reg_weight` set to 0.0; `SDPOLossInput.kl_reg_weight` set to 0.0 to disable KL regularization by default. |
| Evaluation Configuration (`claas/eval/configs/base.yaml`, `claas/eval/types.py`) | `steps_per_batch` increased from 1 to 4, changing the per-batch processing granularity in evaluation. |
| Concise Response Verification (`claas/eval/metrics/verifiers.py`) | Added `MIN_WORDS = 10` constraint to `ConciseVerifier`; responses with fewer than 10 words now fail immediately, before sentence-count checks; docstring added describing concise-but-substantive criteria. |
| PEFT Adapter Context Management (`claas/training/distillation.py`) | `_build_self_teacher_topk` now accepts an optional `peft_model` parameter and disables adapters during inference, with a `contextlib.nullcontext` fallback; the adapter-disabling context is also applied during base-model logit computation. |
| Test Updates (`tests/test_eval_config.py`, `tests/test_eval_runner.py`) | Concise verifier tests updated to enforce ≥10 words and ≤3 sentences; degenerate short-text case added; `steps_per_batch` expectations updated from 1 to 4. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~15 minutes

Poem

🐰 The defaults dance, the learning hops lower,
KL weights fade, evaluations grow slower,
Words now matter—ten at the least!
Adapters sleep while teachers are eased,
Changes cascade, but logic's still clean! 🌿

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 66.67%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and specifically summarizes the two main changes, fixing PEFT base-model contamination and tuning SDPO defaults, which aligns with the PR objectives. |


SDPOLossInput had alpha, is_clip, and kl_reg_weight fields that
duplicated TrainingConfig defaults. Replace with a single `training`
field that accepts the config directly, eliminating the duplication.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
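The refactor described in that commit message can be sketched as follows (a minimal illustration under assumed field values; `alpha` and `is_clip` defaults here are placeholders, and the real types live in `claas/core/types.py`):

```python
from dataclasses import dataclass, field


@dataclass
class TrainingConfig:  # illustrative stand-in for claas/core/types.py
    alpha: float = 0.5         # placeholder default
    is_clip: float = 2.0       # placeholder default
    kl_reg_weight: float = 0.0


# Before: the loss input re-declared fields whose defaults already
# lived on TrainingConfig, so the two could silently drift apart.
@dataclass
class SDPOLossInputOld:
    alpha: float = 0.5
    is_clip: float = 2.0
    kl_reg_weight: float = 0.0


# After: a single `training` field carries the config, so every
# default is defined in exactly one place.
@dataclass
class SDPOLossInput:
    training: TrainingConfig = field(default_factory=TrainingConfig)


loss_input = SDPOLossInput(training=TrainingConfig(kl_reg_weight=0.001))
print(loss_input.training.kl_reg_weight)
```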
@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
claas/training/distillation.py (1)

165-208: ⚠️ Potential issue | 🟡 Minor

Correct _load_or_create_lora return type annotation.

The function is annotated as returning "PeftModel | PeftMixedModel" (line 97), but both code paths only return PeftModel: PeftModel.from_pretrained() in the if-branch and get_peft_model() in the else-branch. This overstated return type propagates to the peft_model parameter in _build_self_teacher_topk (line 173), where it is annotated as "PeftModel | PeftMixedModel | None".

Since PeftMixedModel is never actually instantiated in the codebase and the function always returns PeftModel, the return type should be just "PeftModel". This would eliminate the false contract and accurately reflect what the code actually does.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/distillation.py` around lines 165 - 208, The declared return
type of _load_or_create_lora should be narrowed to "PeftModel" because both
branches return PeftModel (via PeftModel.from_pretrained and get_peft_model);
update the function signature accordingly and then update the peft_model
parameter annotation on _build_self_teacher_topk from "PeftModel |
PeftMixedModel | None" to "PeftModel | None" so the type hints reflect actual
returned/accepted types (ensure you reference _load_or_create_lora,
PeftModel.from_pretrained, get_peft_model, and _build_self_teacher_topk when
making the edits).
🧹 Nitpick comments (2)
claas/training/distillation.py (2)

288-293: base_logprobs forward pass runs even when kl_reg_weight == 0.0 (new default).

With the default kl_reg_weight=0.0, the KL term is zeroed out in compute_sdpo_loss, so base_logprobs contribute nothing to the loss. Yet the full base-model forward pass still executes on every sample — compounded by steps_per_batch=4. Consider guarding it:

⚡ Optional: skip base forward pass when KL weight is zero
```diff
-            with torch.no_grad(), model.disable_adapter():
-                base_output = self.base_model(input_ids=full_ids)
-                base_logits = base_output.logits[:, response_start - 1 : -1, :]
-                base_logprobs = self.functional.log_softmax(base_logits, dim=-1).gather(
-                    -1, response_ids[:, :response_token_count].unsqueeze(-1)
-                ).squeeze(-1)
-
-            del base_output, base_logits
-            torch.cuda.empty_cache()
+            if config.kl_reg_weight > 0.0:
+                with torch.no_grad(), model.disable_adapter():
+                    base_output = self.base_model(input_ids=full_ids)
+                    base_logits = base_output.logits[:, response_start - 1 : -1, :]
+                    base_logprobs = self.functional.log_softmax(base_logits, dim=-1).gather(
+                        -1, response_ids[:, :response_token_count].unsqueeze(-1)
+                    ).squeeze(-1)
+                del base_output, base_logits
+                torch.cuda.empty_cache()
+            else:
+                base_logprobs = torch.zeros(1, response_token_count, device=self.device)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/distillation.py` around lines 288 - 293, The base-model
forward and computation of base_logprobs are executed even when kl_reg_weight is
0.0; guard that work by checking the KL weight before running the base forward
pass. Modify the block that calls base_model and computes
base_logits/base_logprobs (symbols: base_model, base_output, base_logits,
base_logprobs, model.disable_adapter) so it only runs when self.kl_reg_weight
(or the local kl_reg_weight used by compute_sdpo_loss) > 0.0; otherwise set
base_logprobs to an appropriate zero/tensor placeholder so downstream
compute_sdpo_loss can proceed without the extra forward. Ensure the conditional
preserves shapes and device dtype expectations used later (e.g., when gathering
or computing losses) and does not change behavior when kl_reg_weight > 0.0.

176-189: Docstring is missing the new peft_model parameter (and the pre-existing system_prompt).

📝 Suggested docstring addition
```diff
          Args:
              prompt: User prompt.
              feedback: Critique text.
              response_ids: Tokenized sampled response.
              top_k: Number of logits to retain per token.
+            system_prompt: System prompt passed to teacher prompt construction.
+            peft_model: When provided, the teacher forward pass runs under
+                ``peft_model.disable_adapter()`` so LoRA weights are excluded
+                from the base-model scoring.
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/distillation.py` around lines 176 - 189, The docstring in
claas/training/distillation.py that currently lists Args (prompt, feedback,
response_ids, top_k) is missing the new peft_model parameter and the existing
system_prompt; update that function's docstring to add entries for system_prompt
(type: str, short description: system-level prompt prepended to user prompt) and
peft_model (type: any/model object, short description: the PEFT/tuned model used
for scoring or inference), include expected types and how each affects outputs,
and keep the Returns and Reference sections unchanged; place these two new param
lines under the Args section alongside the other parameters so documentation and
signatures match.

ℹ️ Review info

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between 600a31d and 3a3dca5.

📒 Files selected for processing (7)
  • claas/core/types.py
  • claas/eval/configs/base.yaml
  • claas/eval/metrics/verifiers.py
  • claas/eval/types.py
  • claas/training/distillation.py
  • tests/test_eval_config.py
  • tests/test_eval_runner.py

@kfallah kfallah merged commit 8745ad6 into main Feb 25, 2026
3 checks passed
@kfallah kfallah deleted the sdpo-bugfix-and-defaults branch February 25, 2026 02:01