diff --git a/claas/core/types.py b/claas/core/types.py index 419265f..6a81f8f 100644 --- a/claas/core/types.py +++ b/claas/core/types.py @@ -25,7 +25,7 @@ class TrainingConfig(BaseModel): """Training configuration for distillation.""" learning_rate: float = Field( - default=5e-5, + default=3e-5, description="Learning rate for LoRA parameter updates", ) alpha: float = Field( @@ -46,7 +46,7 @@ class TrainingConfig(BaseModel): description="Maximum gradient norm for clipping", ) kl_reg_weight: float = Field( - default=0.001, + default=0.0, ge=0.0, le=1.0, description="Weight for KL regularization to base policy", @@ -71,9 +71,7 @@ class SDPOLossInput(BaseModel): response_mask: Any # torch.Tensor (B, T) old_student_logprobs: Any # torch.Tensor (B, T) response_ids: Any # torch.Tensor (B, T) - alpha: float = 0.5 - is_clip: float = 5.0 - kl_reg_weight: float = 0.001 + training: TrainingConfig = Field(default_factory=TrainingConfig) class SDPOLossResult(TypedDict): diff --git a/claas/eval/configs/base.yaml b/claas/eval/configs/base.yaml index a378ef0..92323c1 100644 --- a/claas/eval/configs/base.yaml +++ b/claas/eval/configs/base.yaml @@ -22,7 +22,7 @@ plots: true num_steps: 20 batch_size: 4 -steps_per_batch: 1 +steps_per_batch: 4 feedback_repetitions: 1 seed: 42 lora_id_prefix: eval diff --git a/claas/eval/metrics/verifiers.py b/claas/eval/metrics/verifiers.py index f1ab6d9..9a579ed 100644 --- a/claas/eval/metrics/verifiers.py +++ b/claas/eval/metrics/verifiers.py @@ -76,7 +76,14 @@ def __call__(self, response: str) -> VerifierResult: class ConciseVerifier: + """Pass when response is concise (<=3 sentences) but still substantive (>=10 words).""" + + MIN_WORDS = 10 + def __call__(self, response: str) -> VerifierResult: + word_count = len(response.split()) + if word_count < self.MIN_WORDS: + return VerifierResult(score=0.0, passed=False) n = _count_sentences(response) if n <= 3: score = 1.0 diff --git a/claas/eval/types.py b/claas/eval/types.py index a175a6f..84d1733 100644 --- a/claas/eval/types.py +++ b/claas/eval/types.py @@ -94,7 +94,7 @@ class EvalConfig: openclaw_url: Optional[str] = None base_model: str = "Qwen/Qwen3-8B" batch_size: int = 4 - steps_per_batch: int = 1 + steps_per_batch: int = 4 feedback_repetitions: int = 1 diff --git a/claas/training/distillation.py b/claas/training/distillation.py index a46ee89..5487303 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -6,6 +6,7 @@ import logging import os import tempfile +from contextlib import nullcontext as _nullcontext from typing import TYPE_CHECKING, cast import torch @@ -169,6 +170,7 @@ def _build_self_teacher_topk( top_k: int, *, system_prompt: str, + peft_model: "PeftModel | PeftMixedModel | None" = None, ) -> tuple["torch.Tensor", "torch.Tensor", str]: """Build top-k teacher logits from the frozen base model. @@ -202,7 +204,8 @@ def _build_self_teacher_topk( teacher_resp_start = teacher_prompt_ids.shape[-1] response_token_count = response_ids.shape[-1] - with torch.no_grad(): + ctx = peft_model.disable_adapter() if peft_model is not None else _nullcontext() + with torch.no_grad(), ctx: teacher_output = self.base_model(input_ids=teacher_full_ids) teacher_logits = teacher_output.logits[:, teacher_resp_start - 1 : -1, :] log_probs = self.functional.log_softmax(teacher_logits, dim=-1) @@ -282,7 +285,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: response_mask = torch.zeros(1, full_ids.shape[-1], device=self.device) response_mask[:, response_start:] = 1.0 - with torch.no_grad(): + with torch.no_grad(), model.disable_adapter(): base_output = self.base_model(input_ids=full_ids) base_logits = base_output.logits[:, response_start - 1 : -1, :] base_logprobs = self.functional.log_softmax(base_logits, dim=-1).gather( @@ -312,6 +315,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: response_ids, config.teacher_top_k, system_prompt=sample.system_prompt, + peft_model=model, ) if teacher_logprobs.shape[0] != response_token_count: @@ -325,9 +329,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: response_mask=response_mask[:, response_start:], old_student_logprobs=old_student_logprobs, response_ids=response_ids[:, :response_token_count], - alpha=config.alpha, - is_clip=config.is_clip, - kl_reg_weight=config.kl_reg_weight, + training=config, ) loss_dict = compute_sdpo_loss(loss_input) batch_loss_tensors.append(loss_dict["loss"]) diff --git a/claas/training/sdpo_loss.py b/claas/training/sdpo_loss.py index 2ead21a..a8c00f8 100644 --- a/claas/training/sdpo_loss.py +++ b/claas/training/sdpo_loss.py @@ -46,9 +46,9 @@ def compute_sdpo_loss(loss_input: SDPOLossInput) -> SDPOLossResult: response_mask = loss_input.response_mask old_student_logprobs = loss_input.old_student_logprobs response_ids = loss_input.response_ids - alpha = loss_input.alpha - is_clip = loss_input.is_clip - kl_reg_weight = loss_input.kl_reg_weight + alpha = loss_input.training.alpha + is_clip = loss_input.training.is_clip + kl_reg_weight = loss_input.training.kl_reg_weight _B, _T, _V = student_logits.shape diff --git a/tests/test_eval_config.py b/tests/test_eval_config.py index 802f5e2..2179bee 100644 --- a/tests/test_eval_config.py +++ b/tests/test_eval_config.py @@ -118,14 +118,23 @@ def test_preference_verifier_callable() -> None: assert result.score == 0.0 assert result.passed is False - # concise: short text should pass - result = configs["concise"].verifier("One sentence. Two sentences. Three.") + # concise: short text should pass (>=10 words, <=3 sentences) + result = configs["concise"].verifier( + "Python is a versatile, high-level programming language known for its readable syntax." + ) assert result.score == 1.0 assert result.passed is True + # concise: degenerate text (too few words) should fail + result = configs["concise"].verifier("Just a dot.") + assert result.score == 0.0 + assert result.passed is False + # concise: verbose text should fail result = configs["concise"].verifier( - "One. Two. Three. Four. Five. Six. Seven. Eight. Nine. Ten." + "One thing to know. Two things to know. Three things to know. " + "Four things to know. Five things to know. Six things to know. " + "Seven things to know. Eight things to know. Nine things to know. Ten things." ) assert result.score < 1.0 assert result.passed is False diff --git a/tests/test_eval_runner.py b/tests/test_eval_runner.py index 1a8370e..fa04c0c 100644 --- a/tests/test_eval_runner.py +++ b/tests/test_eval_runner.py @@ -108,9 +108,9 @@ def test_step_result_from_dict_no_metrics(): def test_steps_per_batch_default(): - """HarnessConfig.steps_per_batch defaults to 1.""" + """HarnessConfig.steps_per_batch defaults to 4.""" config = HarnessConfig() - assert config.steps_per_batch == 1 + assert config.steps_per_batch == 4 def test_steps_per_batch_custom(): diff --git a/tests/test_sdpo_loss.py b/tests/test_sdpo_loss.py index b01482b..26b1ff7 100644 --- a/tests/test_sdpo_loss.py +++ b/tests/test_sdpo_loss.py @@ -6,7 +6,7 @@ torch = pytest.importorskip("torch") -from claas.core.types import SDPOLossInput # noqa: E402 +from claas.core.types import SDPOLossInput, TrainingConfig # noqa: E402 from claas.training.sdpo_loss import _lookup_token_in_topk, compute_sdpo_loss # noqa: E402 @@ -104,8 +104,18 @@ def test_batch_processing(self, device): class TestComputeSdpoLoss: """Tests for compute_sdpo_loss.""" + _TRAINING_FIELDS = {"alpha", "is_clip", "kl_reg_weight"} + @staticmethod def _run_loss(sample_data: dict, **overrides): + training_overrides = { + k: v for k, v in overrides.items() + if k in TestComputeSdpoLoss._TRAINING_FIELDS + } + tensor_overrides = { + k: v for k, v in overrides.items() + if k not in TestComputeSdpoLoss._TRAINING_FIELDS + } payload = { "student_logits": sample_data["student_logits"], "teacher_logprobs": sample_data["teacher_logprobs"], @@ -114,8 +124,9 @@ def _run_loss(sample_data: dict, **overrides): "response_mask": sample_data["response_mask"], "old_student_logprobs": sample_data["old_student_logprobs"], "response_ids": sample_data["response_ids"], + "training": TrainingConfig(**training_overrides), } - payload.update(overrides) + payload.update(tensor_overrides) return compute_sdpo_loss(SDPOLossInput(**payload)) def test_returns_expected_keys(self, sample_data):