Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,13 @@ The eval dashboard at `/v1/eval` displays results from running the eval harness

## References

1. Hübotter et al. (2026). "Reinforcement Learning via Self-Distillation." arXiv:2601.20802
2. SDPO Reference Implementation: https://github.com/lasgroup/SDPO
3. Modal GPU Memory Snapshots: https://modal.com/blog/gpu-mem-snapshots
4. vLLM: https://github.com/vllm-project/vllm
5. PEFT/LoRA: https://github.com/huggingface/peft
6. Tinker SDPO training reference (continualcode): https://github.com/sdan/continualcode
1. Kleine Buening et al. (2026). "Aligning Language Models from User Interactions." [Paper](https://self-distillation.github.io/user_interactions.pdf) · [Code](https://github.com/lasgroup/user_interactions) · [Hindsight template](https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py)
2. Hübotter et al. (2026). "Reinforcement Learning via Self-Distillation." arXiv:2601.20802
3. SDPO Reference Implementation: https://github.com/lasgroup/SDPO
4. Modal GPU Memory Snapshots: https://modal.com/blog/gpu-mem-snapshots
5. vLLM: https://github.com/vllm-project/vllm
6. PEFT/LoRA: https://github.com/huggingface/peft
7. Tinker SDPO training reference (continualcode): https://github.com/sdan/continualcode

## License

Expand Down
13 changes: 13 additions & 0 deletions claas/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ async def chat_completions(
{"role": m.role, "content": coerce_content(m.content)}
for m in req.messages
]
system_parts = [m["content"] for m in messages if m["role"] == "system"]
system_prompt = "\n".join(system_parts) if system_parts else ""
result = await backend.chat_completion(
messages=messages,
model=req.model or "default",
Expand Down Expand Up @@ -259,6 +261,7 @@ async def chat_completions(
response_token_ids=result.response_token_ids,
prompt_token_ids=result.prompt_token_ids,
response_logprobs=result.response_logprobs,
system_prompt=system_prompt,
),
)

Expand Down Expand Up @@ -394,6 +397,15 @@ async def feedback(request: FeedbackBatchRequest) -> FeedbackResponse:
status_code=422,
detail=f"Cached completion has no logprobs (content_hash={content_hash[:16]}…)",
)
if not entry.system_prompt:
raise HTTPException(
status_code=422,
detail=(
f"Cached completion has no system prompt (content_hash={content_hash[:16]}…). "
"The chat completion request must include a system message so the teacher "
"can score under the same context as the student."
),
)
batch_samples.append(
DistillBatchItem(
prompt=entry.prompt,
Expand All @@ -403,6 +415,7 @@ async def feedback(request: FeedbackBatchRequest) -> FeedbackResponse:
prompt_token_ids=entry.prompt_token_ids,
response_token_ids=entry.response_token_ids,
user_prompt=req.prompt,
system_prompt=entry.system_prompt,
)
)

Expand Down
6 changes: 6 additions & 0 deletions claas/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ class DistillBatchItem(BaseModel):
"not a nested chat template."
),
)
system_prompt: str = Field(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep DistillBatchItem compatible with older feedback logs

Making `system_prompt` required causes previously written feedback log files (which do not have this field in `batch_samples`) to fail `FeedbackLogRecord.model_validate(...)`; `read_recent_feedback_logs` then silently skips those files, so `/v1/dashboard` can lose historical records immediately after deploy. Adding a default/optional value here (or an explicit migration path) would preserve dashboard visibility for pre-existing logs.

Useful? React with 👍 / 👎.

description=(
"System prompt from the chat completion request. "
"Passed to the teacher so it scores under the same context as the student."
),
)


class DistillBatchRequestPayload(BaseModel):
Expand Down
10 changes: 10 additions & 0 deletions claas/dashboard/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,15 @@ def feedback_dashboard_rows(records: list[FeedbackLogRecord]) -> str:

# -- Expandable detail row --
sample_sections: list[str] = []
_raw_texts = metrics_payload.get("teacher_scored_texts") if metrics_payload else None
teacher_scored_texts = list(_raw_texts) if isinstance(_raw_texts, list) else []
for item_index, sample in enumerate(record.batch_samples):
teacher_section = ""
if item_index < len(teacher_scored_texts):
teacher_section = (
"<section><h3>Teacher Scored Text</h3>"
"<pre>{teacher_text}</pre></section>"
).format(teacher_text=html.escape(str(teacher_scored_texts[item_index])))
sample_sections.append(
"""
<details{open_attr}>
Expand All @@ -86,6 +94,7 @@ def feedback_dashboard_rows(records: list[FeedbackLogRecord]) -> str:
<section><h3>Prompt</h3><pre>{prompt}</pre></section>
<section><h3>Response</h3><pre>{response}</pre></section>
<section><h3>Feedback</h3><pre>{feedback}</pre></section>
{teacher_section}
</div>
</details>
""".format(
Expand All @@ -96,6 +105,7 @@ def feedback_dashboard_rows(records: list[FeedbackLogRecord]) -> str:
prompt=html.escape(sample.prompt),
response=html.escape(sample.response),
feedback=html.escape(sample.feedback),
teacher_section=teacher_section,
)
)

Expand Down
3 changes: 2 additions & 1 deletion claas/eval/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@

async def _init_lora(config: HarnessConfig, lora_id: str) -> str:
"""Initialize a fresh LoRA adapter via CLaaS API."""
async with httpx.AsyncClient(base_url=config.claas_url, timeout=120.0) as client:
# LoRA init can exceed two minutes when the remote trainer is cold-starting.
async with httpx.AsyncClient(base_url=config.claas_url, timeout=300.0) as client:
resp = await client.post(
"/v1/lora/init",
json={"lora_id": lora_id, "base_model": config.base_model},
Expand Down
4 changes: 3 additions & 1 deletion claas/inference/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
class CompletionCacheEntry:
"""A single cached completion with prompt, response, token IDs, and logprobs."""

__slots__ = ("prompt", "response", "response_token_ids", "prompt_token_ids", "response_logprobs", "created_at")
__slots__ = ("prompt", "response", "response_token_ids", "prompt_token_ids", "response_logprobs", "system_prompt", "created_at")

def __init__(
self,
Expand All @@ -21,12 +21,14 @@ def __init__(
response_token_ids: list[int],
prompt_token_ids: list[int],
response_logprobs: list[float] | None,
system_prompt: str,
) -> None:
self.prompt = prompt
self.response = response
self.response_token_ids = response_token_ids
self.prompt_token_ids = prompt_token_ids
self.response_logprobs = response_logprobs
self.system_prompt = system_prompt
self.created_at = time.monotonic()

def is_expired(self) -> bool:
Expand Down
23 changes: 15 additions & 8 deletions claas/training/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ def _build_self_teacher_topk(
feedback: str,
response_ids: "torch.Tensor",
top_k: int,
) -> tuple["torch.Tensor", "torch.Tensor"]:
*,
system_prompt: str,
) -> tuple["torch.Tensor", "torch.Tensor", str]:
"""Build top-k teacher logits from the frozen base model.

Args:
Expand All @@ -177,18 +179,18 @@ def _build_self_teacher_topk(
top_k: Number of logits to retain per token.

Returns:
Pair of top-k logprobs and indices for each response token.
Triple of (top-k logprobs, top-k indices, teacher_scored_text).

Reference:
Huebotter et al. (2026), "Reinforcement Learning via Self-Distillation"
(arXiv:2601.20802), Section 3.
Kleine Buening et al. (2026), "Aligning Language Models from User Interactions"
https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py
"""

messages = build_teacher_messages(prompt, feedback)
messages = build_teacher_messages(prompt, feedback, system_prompt=system_prompt)
template_messages = teacher_messages_to_chat_template(messages)
teacher_prompt_ids_raw = self.tokenizer.apply_chat_template(
template_messages,
add_generation_prompt=False,
add_generation_prompt=True,
return_tensors="pt",
tokenize=True,
)
Expand All @@ -209,9 +211,10 @@ def _build_self_teacher_topk(
top_logprobs, top_indices = torch.topk(log_probs[0, :response_token_count], k=k, dim=-1)

del teacher_output, teacher_logits, log_probs
teacher_scored_text = self.tokenizer.decode(teacher_full_ids[0].tolist(), skip_special_tokens=False)
del teacher_full_ids, teacher_prompt_ids
torch.cuda.empty_cache()
return top_logprobs, top_indices
return top_logprobs, top_indices, teacher_scored_text

def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
"""Run one SDPO distillation step.
Expand Down Expand Up @@ -256,6 +259,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
batch_kl_reg: list[float] = []
batch_mean_is_ratio: list[float] = []
batch_clip_fraction: list[float] = []
batch_teacher_scored_texts: list[str] = []
tokens_processed = 0

for sample in payload.samples:
Expand Down Expand Up @@ -302,11 +306,12 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
elif old_student_logprobs.shape[1] < response_token_count:
raise ValueError("response_logprobs length must match response token length")

teacher_logprobs, teacher_indices = self._build_self_teacher_topk(
teacher_logprobs, teacher_indices, teacher_scored_text = self._build_self_teacher_topk(
sample.user_prompt,
sample.feedback,
response_ids,
config.teacher_top_k,
system_prompt=sample.system_prompt,
)

if teacher_logprobs.shape[0] != response_token_count:
Expand All @@ -330,6 +335,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
batch_kl_reg.append(loss_dict["kl_reg"])
batch_mean_is_ratio.append(loss_dict["mean_is_ratio"])
batch_clip_fraction.append(loss_dict["clip_fraction"])
batch_teacher_scored_texts.append(teacher_scored_text)

del full_ids, prompt_ids, response_ids, response_mask
del student_logits, base_logprobs, old_student_logprobs
Expand Down Expand Up @@ -377,6 +383,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
"grad_norm": grad_norm_value,
"tokens_processed": tokens_processed,
"batch_size": len(payload.samples),
"teacher_scored_texts": batch_teacher_scored_texts,
},
}
)
Expand Down
13 changes: 9 additions & 4 deletions claas/training/engine/tinker/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ async def distill(
"lr": lr,
"loss_fn": "importance_sampling",
"timestamp": datetime.now(timezone.utc).isoformat(),
"teacher_scored_texts": [m["teacher_scored_text"] for m in sample_metrics],
}

if hasattr(fwd_bwd, "metrics") and fwd_bwd.metrics:
Expand All @@ -265,7 +266,7 @@ async def _build_sample_datum(
tokenizer: Any,
teacher_sampling: Any,
kl_coef: float,
) -> tuple[T.Datum, dict[str, float]]:
) -> tuple[T.Datum, dict[str, float | str]]:
"""Process a single sample into a Datum and per-sample metrics.

This is the per-sample logic extracted from ``distill()`` so that
Expand All @@ -288,11 +289,13 @@ async def _build_sample_datum(

# ── Build teacher prompt (matching local worker: build_teacher_messages) ──
teacher_prompt_source = sample.user_prompt
teacher_messages = build_teacher_messages(teacher_prompt_source, sample.feedback)
teacher_messages = build_teacher_messages(
teacher_prompt_source, sample.feedback, system_prompt=sample.system_prompt
)
template_messages = teacher_messages_to_chat_template(teacher_messages)
teacher_prompt_text = tokenizer.apply_chat_template(
template_messages,
add_generation_prompt=False,
add_generation_prompt=True,
tokenize=False,
)
teacher_prompt_tokens: list[int] = tokenizer.encode(
Expand All @@ -302,6 +305,7 @@ async def _build_sample_datum(
teacher_full_tokens = teacher_prompt_tokens + response_tokens
teacher_prompt_len = len(teacher_prompt_tokens)
teacher_full = T.ModelInput.from_ints(teacher_full_tokens)
teacher_scored_text = tokenizer.decode(teacher_full_tokens, skip_special_tokens=False)

# ── Compute teacher logprobs (base model = self-distillation) ──
teacher_logprobs_full = await teacher_sampling.compute_logprobs_async(teacher_full)
Expand Down Expand Up @@ -346,14 +350,15 @@ async def _build_sample_datum(
adv_abs_mean = sum(abs(a) for a in advantages) / max(len(advantages), 1)
kl_mean = sum(raw_kl_deltas) / max(len(raw_kl_deltas), 1)

metrics = {
metrics: dict[str, float | str] = {
"completion_len": completion_len,
"effective_kl_coef": effective_kl_coef,
"kl_gain": gain,
"adv_mean": adv_mean,
"adv_abs_mean": adv_abs_mean,
"kl_mean": kl_mean,
"adv_abs_mean_raw": adv_abs_mean_raw,
"teacher_scored_text": teacher_scored_text,
}

return datum, metrics
Expand Down
36 changes: 22 additions & 14 deletions claas/training/teacher_helpers.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,51 @@
"""Pure helper functions for teacher prompt building.

These functions have no Modal dependency and can be imported anywhere.

Reference:
Kleine Buening et al. (2026), "Aligning Language Models from User Interactions"
https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py
"""

from __future__ import annotations

from claas.core.config import DEFAULT_SYSTEM_PROMPT
from claas.core.types import ChatMessage


def build_teacher_messages(
prompt: str,
feedback: str | None = None,
system_prompt: str | None = None,
*,
system_prompt: str,
) -> list[ChatMessage]:
"""Build chat messages for teacher prompt (veRL-compatible template).
"""Build chat messages for the hindsight policy (SDPO teacher).

Appends a hindsight context block to the user prompt so the frozen base
model can score the student's response tokens with awareness of the user's
follow-up. Template follows the online SDPO trainer from:
https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py

Template matches veRL's reprompt_template structure:
{prompt}{feedback}\\n\\nCorrectly solve the original question.
The system_prompt must match the one the student saw so the teacher scores
under the same context.

Args:
prompt: The original user prompt
feedback: Optional feedback about the response quality
system_prompt: Optional system prompt
feedback: Optional user follow-up / feedback about the response
system_prompt: System prompt (must match the student's)

Returns:
List of message dicts with 'role' and 'content' keys
"""
if system_prompt is None:
system_prompt = DEFAULT_SYSTEM_PROMPT
messages: list[ChatMessage] = [{"role": "system", "content": system_prompt}]

# Build user content with veRL-style template
if feedback:
feedback_section = (
"\n\nThe following is positive or negative feedback from your earlier attempt:"
f"\n\n{feedback}\n"
block = (
"\n\n=== HINDSIGHT CONTEXT ===\n"
"[The following is a future user message. "
"Use this to guide your answer to the user prompt.]\n"
f"{feedback.strip()}"
)
user_content = f"{prompt}{feedback_section}\n\nCorrectly solve the original question.\n"
user_content = f"{prompt}{block}"
else:
user_content = prompt

Expand Down
1 change: 1 addition & 0 deletions docker/scripts/init-stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def write_openclaw_config() -> None:
"port": 18789,
"mode": "local",
"bind": "loopback",
"controlUi": {"enabled": False},
"http": {
"endpoints": {
"chatCompletions": {"enabled": True},
Expand Down
6 changes: 6 additions & 0 deletions tests/integration/run_tinker_stack_integration.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@ compose() {
docker compose -f "$COMPOSE_FILE" --profile tinker "$@"
}

dump_logs() {
  # Print the last 200 log lines of each service in the compose stack so a
  # run leaves diagnostic output behind.
  printf '%s\n' "==> Docker container logs (last 200 lines each) ..."
  # Best-effort: a failure to fetch logs is ignored so the caller keeps going.
  compose logs --tail=200 || :
}

cleanup() {
  # Dump container logs first: `down -v` below removes the containers (and
  # their volumes), after which the logs can no longer be retrieved.
  dump_logs
  printf '%s\n' "==> Tearing down tinker stack ..."
  # Best-effort teardown: a failing `compose down` is ignored.
  # NOTE(review): presumably registered as an EXIT trap elsewhere in the
  # script — confirm against the trap setup.
  compose down -v || :
}
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/test_local_engine_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def test_local_engine_integration_paths(monkeypatch, tmp_path):
prompt_token_ids=[1, 2],
response_token_ids=[3],
user_prompt="prompt",
system_prompt="You are a helpful assistant.",
)
],
)
Expand Down Expand Up @@ -158,6 +159,7 @@ def test_local_engine_cleanup_failure_propagates(monkeypatch):
prompt_token_ids=[1, 2],
response_token_ids=[3],
user_prompt="prompt",
system_prompt="You are a helpful assistant.",
)
],
)
Expand Down
Loading