Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions claas/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class TrainingConfig(BaseModel):
"""Training configuration for distillation."""

learning_rate: float = Field(
default=5e-5,
default=3e-5,
description="Learning rate for LoRA parameter updates",
)
alpha: float = Field(
Expand All @@ -46,7 +46,7 @@ class TrainingConfig(BaseModel):
description="Maximum gradient norm for clipping",
)
kl_reg_weight: float = Field(
default=0.001,
default=0.0,
ge=0.0,
le=1.0,
description="Weight for KL regularization to base policy",
Expand All @@ -71,9 +71,7 @@ class SDPOLossInput(BaseModel):
response_mask: Any # torch.Tensor (B, T)
old_student_logprobs: Any # torch.Tensor (B, T)
response_ids: Any # torch.Tensor (B, T)
alpha: float = 0.5
is_clip: float = 5.0
kl_reg_weight: float = 0.001
training: TrainingConfig = Field(default_factory=TrainingConfig)


class SDPOLossResult(TypedDict):
Expand Down
2 changes: 1 addition & 1 deletion claas/eval/configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ plots: true

num_steps: 20
batch_size: 4
steps_per_batch: 1
steps_per_batch: 4
feedback_repetitions: 1
seed: 42
lora_id_prefix: eval
Expand Down
7 changes: 7 additions & 0 deletions claas/eval/metrics/verifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,14 @@ def __call__(self, response: str) -> VerifierResult:


class ConciseVerifier:
"""Pass when response is concise (<=3 sentences) but still substantive (>=10 words)."""

MIN_WORDS = 10

def __call__(self, response: str) -> VerifierResult:
word_count = len(response.split())
if word_count < self.MIN_WORDS:
return VerifierResult(score=0.0, passed=False)
n = _count_sentences(response)
if n <= 3:
score = 1.0
Expand Down
2 changes: 1 addition & 1 deletion claas/eval/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class EvalConfig:
openclaw_url: Optional[str] = None
base_model: str = "Qwen/Qwen3-8B"
batch_size: int = 4
steps_per_batch: int = 1
steps_per_batch: int = 4
feedback_repetitions: int = 1


Expand Down
12 changes: 7 additions & 5 deletions claas/training/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import os
import tempfile
from contextlib import nullcontext as _nullcontext
from typing import TYPE_CHECKING, cast

import torch
Expand Down Expand Up @@ -169,6 +170,7 @@ def _build_self_teacher_topk(
top_k: int,
*,
system_prompt: str,
peft_model: "PeftModel | PeftMixedModel | None" = None,
) -> tuple["torch.Tensor", "torch.Tensor", str]:
"""Build top-k teacher logits from the frozen base model.

Expand Down Expand Up @@ -202,7 +204,8 @@ def _build_self_teacher_topk(
teacher_resp_start = teacher_prompt_ids.shape[-1]
response_token_count = response_ids.shape[-1]

with torch.no_grad():
ctx = peft_model.disable_adapter() if peft_model is not None else _nullcontext()
with torch.no_grad(), ctx:
teacher_output = self.base_model(input_ids=teacher_full_ids)
teacher_logits = teacher_output.logits[:, teacher_resp_start - 1 : -1, :]
log_probs = self.functional.log_softmax(teacher_logits, dim=-1)
Expand Down Expand Up @@ -282,7 +285,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
response_mask = torch.zeros(1, full_ids.shape[-1], device=self.device)
response_mask[:, response_start:] = 1.0

with torch.no_grad():
with torch.no_grad(), model.disable_adapter():
base_output = self.base_model(input_ids=full_ids)
base_logits = base_output.logits[:, response_start - 1 : -1, :]
base_logprobs = self.functional.log_softmax(base_logits, dim=-1).gather(
Expand Down Expand Up @@ -312,6 +315,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
response_ids,
config.teacher_top_k,
system_prompt=sample.system_prompt,
peft_model=model,
)

if teacher_logprobs.shape[0] != response_token_count:
Expand All @@ -325,9 +329,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
response_mask=response_mask[:, response_start:],
old_student_logprobs=old_student_logprobs,
response_ids=response_ids[:, :response_token_count],
alpha=config.alpha,
is_clip=config.is_clip,
kl_reg_weight=config.kl_reg_weight,
training=config,
)
loss_dict = compute_sdpo_loss(loss_input)
batch_loss_tensors.append(loss_dict["loss"])
Expand Down
6 changes: 3 additions & 3 deletions claas/training/sdpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def compute_sdpo_loss(loss_input: SDPOLossInput) -> SDPOLossResult:
response_mask = loss_input.response_mask
old_student_logprobs = loss_input.old_student_logprobs
response_ids = loss_input.response_ids
alpha = loss_input.alpha
is_clip = loss_input.is_clip
kl_reg_weight = loss_input.kl_reg_weight
alpha = loss_input.training.alpha
is_clip = loss_input.training.is_clip
kl_reg_weight = loss_input.training.kl_reg_weight

_B, _T, _V = student_logits.shape

Expand Down
15 changes: 12 additions & 3 deletions tests/test_eval_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,23 @@ def test_preference_verifier_callable() -> None:
assert result.score == 0.0
assert result.passed is False

# concise: short text should pass
result = configs["concise"].verifier("One sentence. Two sentences. Three.")
# concise: short text should pass (>=10 words, <=3 sentences)
result = configs["concise"].verifier(
"Python is a versatile, high-level programming language known for its readable syntax."
)
assert result.score == 1.0
assert result.passed is True

# concise: degenerate text (too few words) should fail
result = configs["concise"].verifier("Just a dot.")
assert result.score == 0.0
assert result.passed is False

# concise: verbose text should fail
result = configs["concise"].verifier(
"One. Two. Three. Four. Five. Six. Seven. Eight. Nine. Ten."
"One thing to know. Two things to know. Three things to know. "
"Four things to know. Five things to know. Six things to know. "
"Seven things to know. Eight things to know. Nine things to know. Ten things."
)
assert result.score < 1.0
assert result.passed is False
Expand Down
4 changes: 2 additions & 2 deletions tests/test_eval_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ def test_step_result_from_dict_no_metrics():


def test_steps_per_batch_default():
"""HarnessConfig.steps_per_batch defaults to 1."""
"""HarnessConfig.steps_per_batch defaults to 4."""
config = HarnessConfig()
assert config.steps_per_batch == 1
assert config.steps_per_batch == 4


def test_steps_per_batch_custom():
Expand Down
15 changes: 13 additions & 2 deletions tests/test_sdpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

torch = pytest.importorskip("torch")

from claas.core.types import SDPOLossInput # noqa: E402
from claas.core.types import SDPOLossInput, TrainingConfig # noqa: E402
from claas.training.sdpo_loss import _lookup_token_in_topk, compute_sdpo_loss # noqa: E402


Expand Down Expand Up @@ -104,8 +104,18 @@ def test_batch_processing(self, device):
class TestComputeSdpoLoss:
"""Tests for compute_sdpo_loss."""

_TRAINING_FIELDS = {"alpha", "is_clip", "kl_reg_weight"}

@staticmethod
def _run_loss(sample_data: dict, **overrides):
training_overrides = {
k: v for k, v in overrides.items()
if k in TestComputeSdpoLoss._TRAINING_FIELDS
}
tensor_overrides = {
k: v for k, v in overrides.items()
if k not in TestComputeSdpoLoss._TRAINING_FIELDS
}
payload = {
"student_logits": sample_data["student_logits"],
"teacher_logprobs": sample_data["teacher_logprobs"],
Expand All @@ -114,8 +124,9 @@ def _run_loss(sample_data: dict, **overrides):
"response_mask": sample_data["response_mask"],
"old_student_logprobs": sample_data["old_student_logprobs"],
"response_ids": sample_data["response_ids"],
"training": TrainingConfig(**training_overrides),
}
payload.update(overrides)
payload.update(tensor_overrides)
return compute_sdpo_loss(SDPOLossInput(**payload))

def test_returns_expected_keys(self, sample_data):
Expand Down