diff --git a/README.md b/README.md
index 76ae148..1f1bc44 100644
--- a/README.md
+++ b/README.md
@@ -167,12 +167,13 @@ The eval dashboard at `/v1/eval` displays results from running the eval harness
## References
-1. Hübotter et al. (2026). "Reinforcement Learning via Self-Distillation." arXiv:2601.20802
-2. SDPO Reference Implementation: https://github.com/lasgroup/SDPO
-3. Modal GPU Memory Snapshots: https://modal.com/blog/gpu-mem-snapshots
-4. vLLM: https://github.com/vllm-project/vllm
-5. PEFT/LoRA: https://github.com/huggingface/peft
-6. Tinker SDPO training reference (continualcode): https://github.com/sdan/continualcode
+1. Kleine Buening et al. (2026). "Aligning Language Models from User Interactions." [Paper](https://self-distillation.github.io/user_interactions.pdf) · [Code](https://github.com/lasgroup/user_interactions) · [Hindsight template](https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py)
+2. Hübotter et al. (2026). "Reinforcement Learning via Self-Distillation." arXiv:2601.20802
+3. SDPO Reference Implementation: https://github.com/lasgroup/SDPO
+4. Modal GPU Memory Snapshots: https://modal.com/blog/gpu-mem-snapshots
+5. vLLM: https://github.com/vllm-project/vllm
+6. PEFT/LoRA: https://github.com/huggingface/peft
+7. Tinker SDPO training reference (continualcode): https://github.com/sdan/continualcode
## License
diff --git a/claas/api.py b/claas/api.py
index c67aeab..5284ed9 100644
--- a/claas/api.py
+++ b/claas/api.py
@@ -232,6 +232,8 @@ async def chat_completions(
{"role": m.role, "content": coerce_content(m.content)}
for m in req.messages
]
+ system_parts = [m["content"] for m in messages if m["role"] == "system"]
+ system_prompt = "\n".join(system_parts) if system_parts else ""
result = await backend.chat_completion(
messages=messages,
model=req.model or "default",
@@ -259,6 +261,7 @@ async def chat_completions(
response_token_ids=result.response_token_ids,
prompt_token_ids=result.prompt_token_ids,
response_logprobs=result.response_logprobs,
+ system_prompt=system_prompt,
),
)
@@ -394,6 +397,15 @@ async def feedback(request: FeedbackBatchRequest) -> FeedbackResponse:
status_code=422,
detail=f"Cached completion has no logprobs (content_hash={content_hash[:16]}…)",
)
+ if not entry.system_prompt:
+ raise HTTPException(
+ status_code=422,
+ detail=(
+ f"Cached completion has no system prompt (content_hash={content_hash[:16]}…). "
+ "The chat completion request must include a system message so the teacher "
+ "can score under the same context as the student."
+ ),
+ )
batch_samples.append(
DistillBatchItem(
prompt=entry.prompt,
@@ -403,6 +415,7 @@ async def feedback(request: FeedbackBatchRequest) -> FeedbackResponse:
prompt_token_ids=entry.prompt_token_ids,
response_token_ids=entry.response_token_ids,
user_prompt=req.prompt,
+ system_prompt=entry.system_prompt,
)
)
diff --git a/claas/core/types.py b/claas/core/types.py
index bca5151..419265f 100644
--- a/claas/core/types.py
+++ b/claas/core/types.py
@@ -120,6 +120,12 @@ class DistillBatchItem(BaseModel):
"not a nested chat template."
),
)
+ system_prompt: str = Field(
+ description=(
+ "System prompt from the chat completion request. "
+ "Passed to the teacher so it scores under the same context as the student."
+ ),
+ )
class DistillBatchRequestPayload(BaseModel):
diff --git a/claas/dashboard/rendering.py b/claas/dashboard/rendering.py
index 3ac26ea..536c717 100644
--- a/claas/dashboard/rendering.py
+++ b/claas/dashboard/rendering.py
@@ -77,7 +77,15 @@ def feedback_dashboard_rows(records: list[FeedbackLogRecord]) -> str:
# -- Expandable detail row --
sample_sections: list[str] = []
+ _raw_texts = metrics_payload.get("teacher_scored_texts") if metrics_payload else None
+ teacher_scored_texts = list(_raw_texts) if isinstance(_raw_texts, list) else []
for item_index, sample in enumerate(record.batch_samples):
+ teacher_section = ""
+ if item_index < len(teacher_scored_texts):
+ teacher_section = (
+ "<strong>Teacher Scored Text</strong><br>"
+ "<pre>{teacher_text}</pre>"
+ ).format(teacher_text=html.escape(str(teacher_scored_texts[item_index])))
sample_sections.append(
"""
@@ -86,6 +94,7 @@ def feedback_dashboard_rows(records: list[FeedbackLogRecord]) -> str:
+ {teacher_section}
""".format(
@@ -96,6 +105,7 @@ def feedback_dashboard_rows(records: list[FeedbackLogRecord]) -> str:
prompt=html.escape(sample.prompt),
response=html.escape(sample.response),
feedback=html.escape(sample.feedback),
+ teacher_section=teacher_section,
)
)
diff --git a/claas/eval/runner.py b/claas/eval/runner.py
index 3912e10..c449d0d 100644
--- a/claas/eval/runner.py
+++ b/claas/eval/runner.py
@@ -40,7 +40,8 @@
async def _init_lora(config: HarnessConfig, lora_id: str) -> str:
"""Initialize a fresh LoRA adapter via CLaaS API."""
- async with httpx.AsyncClient(base_url=config.claas_url, timeout=120.0) as client:
+ # LoRA init can exceed two minutes when the remote trainer is cold-starting.
+ async with httpx.AsyncClient(base_url=config.claas_url, timeout=300.0) as client:
resp = await client.post(
"/v1/lora/init",
json={"lora_id": lora_id, "base_model": config.base_model},
diff --git a/claas/inference/cache.py b/claas/inference/cache.py
index f7a5058..2f0f454 100644
--- a/claas/inference/cache.py
+++ b/claas/inference/cache.py
@@ -12,7 +12,7 @@
class CompletionCacheEntry:
"""A single cached completion with prompt, response, token IDs, and logprobs."""
- __slots__ = ("prompt", "response", "response_token_ids", "prompt_token_ids", "response_logprobs", "created_at")
+ __slots__ = ("prompt", "response", "response_token_ids", "prompt_token_ids", "response_logprobs", "system_prompt", "created_at")
def __init__(
self,
@@ -21,12 +21,14 @@ def __init__(
response_token_ids: list[int],
prompt_token_ids: list[int],
response_logprobs: list[float] | None,
+ system_prompt: str,
) -> None:
self.prompt = prompt
self.response = response
self.response_token_ids = response_token_ids
self.prompt_token_ids = prompt_token_ids
self.response_logprobs = response_logprobs
+ self.system_prompt = system_prompt
self.created_at = time.monotonic()
def is_expired(self) -> bool:
diff --git a/claas/training/distillation.py b/claas/training/distillation.py
index b64c964..a46ee89 100644
--- a/claas/training/distillation.py
+++ b/claas/training/distillation.py
@@ -167,7 +167,9 @@ def _build_self_teacher_topk(
feedback: str,
response_ids: "torch.Tensor",
top_k: int,
- ) -> tuple["torch.Tensor", "torch.Tensor"]:
+ *,
+ system_prompt: str,
+ ) -> tuple["torch.Tensor", "torch.Tensor", str]:
"""Build top-k teacher logits from the frozen base model.
Args:
@@ -177,18 +179,18 @@ def _build_self_teacher_topk(
top_k: Number of logits to retain per token.
Returns:
- Pair of top-k logprobs and indices for each response token.
+ Triple of (top-k logprobs, top-k indices, teacher_scored_text).
Reference:
- Huebotter et al. (2026), "Reinforcement Learning via Self-Distillation"
- (arXiv:2601.20802), Section 3.
+ Kleine Buening et al. (2026), "Aligning Language Models from User Interactions"
+ https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py
"""
- messages = build_teacher_messages(prompt, feedback)
+ messages = build_teacher_messages(prompt, feedback, system_prompt=system_prompt)
template_messages = teacher_messages_to_chat_template(messages)
teacher_prompt_ids_raw = self.tokenizer.apply_chat_template(
template_messages,
- add_generation_prompt=False,
+ add_generation_prompt=True,
return_tensors="pt",
tokenize=True,
)
@@ -209,9 +211,10 @@ def _build_self_teacher_topk(
top_logprobs, top_indices = torch.topk(log_probs[0, :response_token_count], k=k, dim=-1)
del teacher_output, teacher_logits, log_probs
+ teacher_scored_text = self.tokenizer.decode(teacher_full_ids[0].tolist(), skip_special_tokens=False)
del teacher_full_ids, teacher_prompt_ids
torch.cuda.empty_cache()
- return top_logprobs, top_indices
+ return top_logprobs, top_indices, teacher_scored_text
def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
"""Run one SDPO distillation step.
@@ -256,6 +259,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
batch_kl_reg: list[float] = []
batch_mean_is_ratio: list[float] = []
batch_clip_fraction: list[float] = []
+ batch_teacher_scored_texts: list[str] = []
tokens_processed = 0
for sample in payload.samples:
@@ -302,11 +306,12 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
elif old_student_logprobs.shape[1] < response_token_count:
raise ValueError("response_logprobs length must match response token length")
- teacher_logprobs, teacher_indices = self._build_self_teacher_topk(
+ teacher_logprobs, teacher_indices, teacher_scored_text = self._build_self_teacher_topk(
sample.user_prompt,
sample.feedback,
response_ids,
config.teacher_top_k,
+ system_prompt=sample.system_prompt,
)
if teacher_logprobs.shape[0] != response_token_count:
@@ -330,6 +335,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
batch_kl_reg.append(loss_dict["kl_reg"])
batch_mean_is_ratio.append(loss_dict["mean_is_ratio"])
batch_clip_fraction.append(loss_dict["clip_fraction"])
+ batch_teacher_scored_texts.append(teacher_scored_text)
del full_ids, prompt_ids, response_ids, response_mask
del student_logits, base_logprobs, old_student_logprobs
@@ -377,6 +383,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse:
"grad_norm": grad_norm_value,
"tokens_processed": tokens_processed,
"batch_size": len(payload.samples),
+ "teacher_scored_texts": batch_teacher_scored_texts,
},
}
)
diff --git a/claas/training/engine/tinker/engine.py b/claas/training/engine/tinker/engine.py
index af45737..c5db847 100644
--- a/claas/training/engine/tinker/engine.py
+++ b/claas/training/engine/tinker/engine.py
@@ -252,6 +252,7 @@ async def distill(
"lr": lr,
"loss_fn": "importance_sampling",
"timestamp": datetime.now(timezone.utc).isoformat(),
+ "teacher_scored_texts": [m["teacher_scored_text"] for m in sample_metrics],
}
if hasattr(fwd_bwd, "metrics") and fwd_bwd.metrics:
@@ -265,7 +266,7 @@ async def _build_sample_datum(
tokenizer: Any,
teacher_sampling: Any,
kl_coef: float,
-) -> tuple[T.Datum, dict[str, float]]:
+) -> tuple[T.Datum, dict[str, float | str]]:
"""Process a single sample into a Datum and per-sample metrics.
This is the per-sample logic extracted from ``distill()`` so that
@@ -288,11 +289,13 @@ async def _build_sample_datum(
# ── Build teacher prompt (matching local worker: build_teacher_messages) ──
teacher_prompt_source = sample.user_prompt
- teacher_messages = build_teacher_messages(teacher_prompt_source, sample.feedback)
+ teacher_messages = build_teacher_messages(
+ teacher_prompt_source, sample.feedback, system_prompt=sample.system_prompt
+ )
template_messages = teacher_messages_to_chat_template(teacher_messages)
teacher_prompt_text = tokenizer.apply_chat_template(
template_messages,
- add_generation_prompt=False,
+ add_generation_prompt=True,
tokenize=False,
)
teacher_prompt_tokens: list[int] = tokenizer.encode(
@@ -302,6 +305,7 @@ async def _build_sample_datum(
teacher_full_tokens = teacher_prompt_tokens + response_tokens
teacher_prompt_len = len(teacher_prompt_tokens)
teacher_full = T.ModelInput.from_ints(teacher_full_tokens)
+ teacher_scored_text = tokenizer.decode(teacher_full_tokens, skip_special_tokens=False)
# ── Compute teacher logprobs (base model = self-distillation) ──
teacher_logprobs_full = await teacher_sampling.compute_logprobs_async(teacher_full)
@@ -346,7 +350,7 @@ async def _build_sample_datum(
adv_abs_mean = sum(abs(a) for a in advantages) / max(len(advantages), 1)
kl_mean = sum(raw_kl_deltas) / max(len(raw_kl_deltas), 1)
- metrics = {
+ metrics: dict[str, float | str] = {
"completion_len": completion_len,
"effective_kl_coef": effective_kl_coef,
"kl_gain": gain,
@@ -354,6 +358,7 @@ async def _build_sample_datum(
"adv_abs_mean": adv_abs_mean,
"kl_mean": kl_mean,
"adv_abs_mean_raw": adv_abs_mean_raw,
+ "teacher_scored_text": teacher_scored_text,
}
return datum, metrics
diff --git a/claas/training/teacher_helpers.py b/claas/training/teacher_helpers.py
index 85a93ea..5f0d8ab 100644
--- a/claas/training/teacher_helpers.py
+++ b/claas/training/teacher_helpers.py
@@ -1,43 +1,51 @@
"""Pure helper functions for teacher prompt building.
These functions have no Modal dependency and can be imported anywhere.
+
+Reference:
+ Kleine Buening et al. (2026), "Aligning Language Models from User Interactions"
+ https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py
"""
from __future__ import annotations
-from claas.core.config import DEFAULT_SYSTEM_PROMPT
from claas.core.types import ChatMessage
def build_teacher_messages(
prompt: str,
feedback: str | None = None,
- system_prompt: str | None = None,
+ *,
+ system_prompt: str,
) -> list[ChatMessage]:
- """Build chat messages for teacher prompt (veRL-compatible template).
+ """Build chat messages for the hindsight policy (SDPO teacher).
+
+ Appends a hindsight context block to the user prompt so the frozen base
+ model can score the student's response tokens with awareness of the user's
+ follow-up. Template follows the online SDPO trainer from:
+ https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py
- Template matches veRL's reprompt_template structure:
- {prompt}{feedback}\\n\\nCorrectly solve the original question.
+ The system_prompt must match the one the student saw so the teacher scores
+ under the same context.
Args:
prompt: The original user prompt
- feedback: Optional feedback about the response quality
- system_prompt: Optional system prompt
+ feedback: Optional user follow-up / feedback about the response
+ system_prompt: System prompt (must match the student's)
Returns:
List of message dicts with 'role' and 'content' keys
"""
- if system_prompt is None:
- system_prompt = DEFAULT_SYSTEM_PROMPT
messages: list[ChatMessage] = [{"role": "system", "content": system_prompt}]
- # Build user content with veRL-style template
if feedback:
- feedback_section = (
- "\n\nThe following is positive or negative feedback from your earlier attempt:"
- f"\n\n{feedback}\n"
+ block = (
+ "\n\n=== HINDSIGHT CONTEXT ===\n"
+ "[The following is a future user message. "
+ "Use this to guide your answer to the user prompt.]\n"
+ f"{feedback.strip()}"
)
- user_content = f"{prompt}{feedback_section}\n\nCorrectly solve the original question.\n"
+ user_content = f"{prompt}{block}"
else:
user_content = prompt
diff --git a/docker/scripts/init-stack.py b/docker/scripts/init-stack.py
index 35fe431..cf654b8 100755
--- a/docker/scripts/init-stack.py
+++ b/docker/scripts/init-stack.py
@@ -260,6 +260,7 @@ def write_openclaw_config() -> None:
"port": 18789,
"mode": "local",
"bind": "loopback",
+ "controlUi": {"enabled": False},
"http": {
"endpoints": {
"chatCompletions": {"enabled": True},
diff --git a/tests/integration/run_tinker_stack_integration.sh b/tests/integration/run_tinker_stack_integration.sh
index 54cca83..3e36791 100755
--- a/tests/integration/run_tinker_stack_integration.sh
+++ b/tests/integration/run_tinker_stack_integration.sh
@@ -27,7 +27,13 @@ compose() {
docker compose -f "$COMPOSE_FILE" --profile tinker "$@"
}
+dump_logs() {
+ echo "==> Docker container logs (last 200 lines each) ..."
+ compose logs --tail=200 || true
+}
+
cleanup() {
+ dump_logs
echo "==> Tearing down tinker stack ..."
compose down -v || true
}
diff --git a/tests/integration/test_local_engine_integration.py b/tests/integration/test_local_engine_integration.py
index aefbd4a..06e945c 100644
--- a/tests/integration/test_local_engine_integration.py
+++ b/tests/integration/test_local_engine_integration.py
@@ -107,6 +107,7 @@ def test_local_engine_integration_paths(monkeypatch, tmp_path):
prompt_token_ids=[1, 2],
response_token_ids=[3],
user_prompt="prompt",
+ system_prompt="You are a helpful assistant.",
)
],
)
@@ -158,6 +159,7 @@ def test_local_engine_cleanup_failure_propagates(monkeypatch):
prompt_token_ids=[1, 2],
response_token_ids=[3],
user_prompt="prompt",
+ system_prompt="You are a helpful assistant.",
)
],
)
diff --git a/tests/integration/test_tinker_stack_integration.py b/tests/integration/test_tinker_stack_integration.py
index aa18f46..926aa6e 100644
--- a/tests/integration/test_tinker_stack_integration.py
+++ b/tests/integration/test_tinker_stack_integration.py
@@ -112,7 +112,11 @@ def test_full_lifecycle(self, tinker_stack: TinkerStack) -> None:
assert lora_id in list_resp.json()["loras"]
# 3. Inference through CLaaS API
- chat_messages = [{"role": "user", "content": user_prompt}]
+ system_prompt = "You are a helpful assistant."
+ chat_messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+ ]
chat_payload = {
"model": tinker_stack.model,
"messages": chat_messages,
@@ -136,7 +140,7 @@ def test_full_lifecycle(self, tinker_stack: TinkerStack) -> None:
# 4. Distill via feedback endpoint (teacher_mode=self)
# The API resolves prompt, response, and logprobs from
# the completion cache by hashing the response text.
- teacher_messages = build_teacher_messages(user_prompt, feedback_text)
+ teacher_messages = build_teacher_messages(user_prompt, feedback_text, system_prompt=system_prompt)
logger.info(
"Teacher messages (built by engine for self-distillation):\n%s",
json.dumps(teacher_messages, indent=2),
diff --git a/tests/test_api.py b/tests/test_api.py
index 3ca00c5..678ad81 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -108,7 +108,15 @@ def __init__(self):
self.remote = _RemoteCallFailure()
-def _seed_cache(visible_response: str, *, logprobs: list[float] | None = None) -> None:
+_DEFAULT_TEST_SYSTEM_PROMPT = "You are a helpful assistant."
+
+
+def _seed_cache(
+ visible_response: str,
+ *,
+ logprobs: list[float] | None = None,
+ system_prompt: str = _DEFAULT_TEST_SYSTEM_PROMPT,
+) -> None:
"""Pre-populate the completion cache keyed by normalized SHA-256 of visible_response."""
content_hash = hashlib.sha256(normalize_for_hash(visible_response).encode("utf-8")).hexdigest()
completion_cache.put(
@@ -119,6 +127,7 @@ def _seed_cache(visible_response: str, *, logprobs: list[float] | None = None) -
response_token_ids=[10, 20],
prompt_token_ids=[1, 2, 3],
response_logprobs=logprobs if logprobs is not None else [-0.1, -0.2],
+ system_prompt=system_prompt,
),
)
@@ -455,6 +464,83 @@ async def fake_run_distill(payload):
assert sample["response"] == "cached-response"
+def test_feedback_system_prompt_flows_from_cache(monkeypatch, tmp_path):
+ """system_prompt is resolved from the completion cache into DistillBatchItem."""
+ from claas import api
+
+ _mock_config(monkeypatch, "tinker")
+ _seed_cache("sys prompt response", system_prompt="You are a pirate assistant.")
+ captured_payload = {}
+
+ class _Engine:
+ async def lora_exists(self, _lora_id):
+ return LoraExistsPayload(exists=True)
+
+ async def fake_run_distill(payload):
+ captured_payload["samples"] = [s.model_dump() for s in payload.samples]
+ return DistillResponse(lora_id="user/model", metadata={})
+
+ monkeypatch.setattr(api, "_get_training_engine", lambda: _Engine())
+ monkeypatch.setattr(api, "_run_distill", fake_run_distill)
+ monkeypatch.setattr(feedback_log_mod, "write_feedback_log", lambda _r, _d: str(tmp_path / "log.json"))
+
+ client = TestClient(web_app)
+ response = client.post(
+ "/v1/feedback",
+ json={
+ "requests": [
+ {
+ "lora_id": "user/model",
+ "prompt": "clean prompt",
+ "response": "sys prompt response",
+ "feedback": "good",
+ "training": {},
+ }
+ ],
+ "orchestration": {"sleep_before": False, "wake_after": False},
+ },
+ )
+
+ assert response.status_code == 200
+ sample = captured_payload["samples"][0]
+ assert sample["system_prompt"] == "You are a pirate assistant."
+
+
+def test_feedback_missing_system_prompt_returns_422(monkeypatch, tmp_path):
+ """Cache entry with empty system_prompt returns 422."""
+ from claas import api
+
+ _mock_config(monkeypatch, "tinker")
+ _seed_cache("no sys prompt response", system_prompt="")
+
+ class _Engine:
+ async def lora_exists(self, _lora_id):
+ return LoraExistsPayload(exists=True)
+
+ monkeypatch.setattr(api, "_get_training_engine", lambda: _Engine())
+ monkeypatch.setattr(feedback_log_mod, "write_feedback_log", lambda _r, _d: str(tmp_path / "log.json"))
+
+ client = TestClient(web_app)
+ response = client.post(
+ "/v1/feedback",
+ json={
+ "requests": [
+ {
+ "lora_id": "user/model",
+ "prompt": "clean prompt",
+ "response": "no sys prompt response",
+ "feedback": "good",
+ "training": {},
+ }
+ ],
+ "orchestration": {"sleep_before": False, "wake_after": False},
+ },
+ )
+
+ assert response.status_code == 422
+ assert "no system prompt" in response.json()["detail"]
+
+
def test_feedback_cache_miss_returns_404(monkeypatch, tmp_path):
"""Missing cache entry returns 404 before any orchestration."""
from claas import api
@@ -505,6 +591,7 @@ def test_feedback_missing_logprobs_returns_422(monkeypatch, tmp_path):
response_token_ids=[10],
prompt_token_ids=[1],
response_logprobs=None,
+ system_prompt="You are a helpful assistant.",
),
)
@@ -834,7 +921,8 @@ def test_dashboard_renders_latest_records(monkeypatch, tmp_path):
"response_logprobs": [-0.1],
"prompt_token_ids": [1],
"response_token_ids": [2],
- "user_prompt": "prompt-a"
+ "user_prompt": "prompt-a",
+ "system_prompt": "You are a helpful assistant."
}
],
"vllm": {
@@ -917,7 +1005,8 @@ def test_dashboard_truncates_prompt_preview(monkeypatch, tmp_path):
"response_logprobs": [-0.1],
"prompt_token_ids": [1],
"response_token_ids": [2],
- "user_prompt": "{long_prompt}"
+ "user_prompt": "{long_prompt}",
+ "system_prompt": "You are a helpful assistant."
}}
],
"vllm": {{
@@ -985,7 +1074,8 @@ def test_dashboard_renders_one_row_per_batch_item(monkeypatch, tmp_path):
"response_logprobs": [-0.1],
"prompt_token_ids": [1],
"response_token_ids": [2],
- "user_prompt": "prompt-d1"
+ "user_prompt": "prompt-d1",
+ "system_prompt": "You are a helpful assistant."
},
{
"prompt": "prompt-d2",
@@ -994,7 +1084,8 @@ def test_dashboard_renders_one_row_per_batch_item(monkeypatch, tmp_path):
"response_logprobs": [-0.2],
"prompt_token_ids": [1],
"response_token_ids": [2],
- "user_prompt": "prompt-d2"
+ "user_prompt": "prompt-d2",
+ "system_prompt": "You are a helpful assistant."
}
],
"vllm": {
@@ -1034,6 +1125,132 @@ def test_dashboard_renders_one_row_per_batch_item(monkeypatch, tmp_path):
assert "2 samples" in response.text
+def test_dashboard_renders_teacher_scored_text(monkeypatch, tmp_path):
+ """Dashboard shows Teacher Scored Text when present in distill metadata."""
+ _mock_config(monkeypatch, "local", feedback_log_dir=str(tmp_path))
+ (tmp_path / "20240101T000010-t.json").write_text(
+ """
+{
+ "request_id": "t",
+ "timestamp_utc": "2024-01-01T00:00:10Z",
+ "status": "ok",
+ "phase": "done",
+ "lora_id": "user/model",
+ "requests": [
+ {
+ "lora_id": "user/model",
+ "prompt": "prompt-t",
+ "response": "response-t",
+ "feedback": "feedback-t"
+ }
+ ],
+ "batch_samples": [
+ {
+ "prompt": "prompt-t",
+ "response": "response-t",
+ "feedback": "feedback-t",
+ "response_logprobs": [-0.1],
+ "prompt_token_ids": [1],
+ "response_token_ids": [2],
+ "user_prompt": "prompt-t",
+ "system_prompt": "You are a helpful assistant."
+ }
+ ],
+ "vllm": {
+ "slept": false,
+ "woke": false
+ },
+ "timing_ms": {
+ "sleep": 0,
+ "distill": 2,
+ "save": 0,
+ "wake": 0,
+ "logprobs": 0,
+ "total": 2
+ },
+ "distill_result": {
+ "lora_id": "user/model",
+ "metadata": {
+ "loss": 0.1,
+ "teacher_scored_texts": ["<|im_start|>system\\nYou are helpful<|im_end|>\\n<|im_start|>user\\nWhat is 2+2?<|im_end|>\\n<|im_start|>assistant\\nFour"]
+ }
+ },
+ "error": null
+}
+""".strip(),
+ encoding="utf-8",
+ )
+
+ client = TestClient(web_app)
+ response = client.get("/v1/dashboard")
+
+ assert response.status_code == 200
+ assert "Teacher Scored Text" in response.text
+ assert "What is 2+2?" in response.text
+
+
+def test_dashboard_omits_teacher_text_when_absent(monkeypatch, tmp_path):
+ """Dashboard gracefully handles old logs without teacher_scored_texts."""
+ _mock_config(monkeypatch, "local", feedback_log_dir=str(tmp_path))
+ (tmp_path / "20240101T000011-u.json").write_text(
+ """
+{
+ "request_id": "u",
+ "timestamp_utc": "2024-01-01T00:00:11Z",
+ "status": "ok",
+ "phase": "done",
+ "lora_id": "user/model",
+ "requests": [
+ {
+ "lora_id": "user/model",
+ "prompt": "prompt-u",
+ "response": "response-u",
+ "feedback": "feedback-u"
+ }
+ ],
+ "batch_samples": [
+ {
+ "prompt": "prompt-u",
+ "response": "response-u",
+ "feedback": "feedback-u",
+ "response_logprobs": [-0.1],
+ "prompt_token_ids": [1],
+ "response_token_ids": [2],
+ "user_prompt": "prompt-u",
+ "system_prompt": "You are a helpful assistant."
+ }
+ ],
+ "vllm": {
+ "slept": false,
+ "woke": false
+ },
+ "timing_ms": {
+ "sleep": 0,
+ "distill": 2,
+ "save": 0,
+ "wake": 0,
+ "logprobs": 0,
+ "total": 2
+ },
+ "distill_result": {
+ "lora_id": "user/model",
+ "metadata": {
+ "loss": 0.1
+ }
+ },
+ "error": null
+}
+""".strip(),
+ encoding="utf-8",
+ )
+
+ client = TestClient(web_app)
+ response = client.get("/v1/dashboard")
+
+ assert response.status_code == 200
+ assert "Teacher Scored Text" not in response.text
+
+
def test_feedback_recent_route_is_removed():
client = TestClient(web_app)
response = client.get("/v1/feedback/recent")
diff --git a/tests/test_local_training_engine.py b/tests/test_local_training_engine.py
index 0a2d764..3d72027 100644
--- a/tests/test_local_training_engine.py
+++ b/tests/test_local_training_engine.py
@@ -54,6 +54,7 @@ def test_local_engine_distill_propagates_cleanup_error(monkeypatch):
prompt_token_ids=[1, 2],
response_token_ids=[3],
user_prompt="p",
+ system_prompt="You are a helpful assistant.",
)
],
)
diff --git a/tests/test_teacher_helpers.py b/tests/test_teacher_helpers.py
index 295b2b9..1aa9223 100644
--- a/tests/test_teacher_helpers.py
+++ b/tests/test_teacher_helpers.py
@@ -2,6 +2,8 @@
from __future__ import annotations
+import pytest
+
from claas.core.types import ChatMessage
from claas.training.teacher_helpers import (
build_teacher_messages,
@@ -11,30 +13,32 @@
class TestBuildTeacherMessages:
def test_with_feedback(self):
- msgs = build_teacher_messages("What is 2+2?", "Too verbose")
+ msgs = build_teacher_messages("What is 2+2?", "Too verbose", system_prompt="You are helpful.")
assert len(msgs) == 2
assert msgs[0]["role"] == "system"
+ assert msgs[0]["content"] == "You are helpful."
assert msgs[1]["role"] == "user"
assert "What is 2+2?" in msgs[1]["content"]
assert "Too verbose" in msgs[1]["content"]
- assert "Correctly solve the original question" in msgs[1]["content"]
+ assert "=== HINDSIGHT CONTEXT ===" in msgs[1]["content"]
+ assert "Use this to guide your answer to the user prompt" in msgs[1]["content"]
def test_without_feedback(self):
- msgs = build_teacher_messages("What is 2+2?")
+ msgs = build_teacher_messages("What is 2+2?", system_prompt="You are helpful.")
assert len(msgs) == 2
assert msgs[1]["content"] == "What is 2+2?"
def test_without_feedback_explicit_none(self):
- msgs = build_teacher_messages("Hello", None)
+ msgs = build_teacher_messages("Hello", None, system_prompt="You are helpful.")
assert msgs[1]["content"] == "Hello"
- def test_custom_system_prompt(self):
- msgs = build_teacher_messages("x", system_prompt="Be concise.")
- assert msgs[0]["content"] == "Be concise."
-
- def test_default_system_prompt(self):
- msgs = build_teacher_messages("x")
- assert "helpful assistant" in msgs[0]["content"]
+ @pytest.mark.parametrize("system_prompt", [
+ "Be concise.",
+ "You are a pirate assistant who speaks in pirate lingo.",
+ ])
+ def test_system_prompt_matches_input(self, system_prompt: str):
+ msgs = build_teacher_messages("x", system_prompt=system_prompt)
+ assert msgs[0]["content"] == system_prompt
class TestTeacherMessagesToChatTemplate:
@@ -50,5 +54,3 @@ def test_converts_typed_messages(self):
]
# Returns new list (not the same object)
assert result is not msgs
-
-
diff --git a/tests/test_tinker_engine.py b/tests/test_tinker_engine.py
index b0818dc..7dd35ca 100644
--- a/tests/test_tinker_engine.py
+++ b/tests/test_tinker_engine.py
@@ -84,6 +84,7 @@ def mock_tokenizer():
tok = MagicMock()
tok.encode.side_effect = lambda text, **kw: list(range(len(text)))
tok.apply_chat_template.return_value = "chat-template-output"
+ tok.decode.side_effect = lambda ids, **kw: f"decoded:{ids}"
return tok
@@ -417,6 +418,7 @@ def test_engine_distill_full_flow(tinker_engine, mock_training_client):
prompt_token_ids=[0, 1, 2, 3, 4],
response_token_ids=[0, 1, 2, 3, 4],
user_prompt="Hello",
+ system_prompt="You are a helpful assistant.",
)
],
)
@@ -436,6 +438,8 @@ def test_engine_distill_full_flow(tinker_engine, mock_training_client):
assert "adv_mean" in result.metadata
assert result.metadata["loss_fn"] == "importance_sampling"
assert result.metadata["tinker_fwd_metrics"] == {"loss": 0.42}
+ assert "teacher_scored_texts" in result.metadata
+ assert len(result.metadata["teacher_scored_texts"]) == 1
# Verify the training client was restored from the init checkpoint
mock_service.create_training_client_from_state_async.assert_called_once_with(
@@ -493,6 +497,7 @@ def test_engine_distill_uses_provided_response_logprobs(tinker_engine, mock_trai
prompt_token_ids=[0, 1, 2, 3, 4],
response_token_ids=[0, 1, 2, 3, 4],
user_prompt="Hello",
+ system_prompt="You are a helpful assistant.",
)
],
)
@@ -538,6 +543,7 @@ def test_engine_distill_batch_multiple_samples(tinker_engine, mock_training_clie
prompt_token_ids=[0, 1, 2, 3, 4],
response_token_ids=[0, 1, 2, 3, 4],
user_prompt="Hello",
+ system_prompt="You are a helpful assistant.",
),
# Sample 2: 5 tokens
DistillBatchItem(
@@ -548,6 +554,7 @@ def test_engine_distill_batch_multiple_samples(tinker_engine, mock_training_clie
prompt_token_ids=[0, 1],
response_token_ids=[0, 1, 2, 3, 4],
user_prompt="Hi",
+ system_prompt="You are a helpful assistant.",
),
# Sample 3: 3 tokens
DistillBatchItem(
@@ -558,6 +565,7 @@ def test_engine_distill_batch_multiple_samples(tinker_engine, mock_training_clie
prompt_token_ids=[0, 1, 2],
response_token_ids=[0, 1, 2],
user_prompt="Hey",
+ system_prompt="You are a helpful assistant.",
),
]
@@ -579,6 +587,10 @@ def test_engine_distill_batch_multiple_samples(tinker_engine, mock_training_clie
# 5 + 5 + 3 = 13
assert result.metadata["completion_len"] == 13
+ # Teacher scored texts: one per sample
+ assert "teacher_scored_texts" in result.metadata
+ assert len(result.metadata["teacher_scored_texts"]) == 3
+
# Averaged metrics are present
assert "adv_mean" in result.metadata
assert "adv_abs_mean" in result.metadata
@@ -713,6 +725,7 @@ def test_engine_distill_uses_response_token_ids(tinker_engine, mock_training_cli
prompt_token_ids=prompt_ids,
response_token_ids=response_ids,
user_prompt="What is 2+2?",
+ system_prompt="You are a helpful assistant.",
)
],
)
@@ -757,13 +770,14 @@ def test_engine_distill_uses_user_prompt_for_teacher(tinker_engine, mock_trainin
prompt_token_ids=[10, 20, 30],
response_token_ids=[40, 50, 60, 70],
user_prompt="What is 2+2?",
+ system_prompt="You are a pirate.",
)
],
)
with patch("claas.training.engine.tinker.engine.build_teacher_messages") as mock_build:
mock_build.return_value = [
- {"role": "system", "content": "You are helpful"},
+ {"role": "system", "content": "You are a pirate."},
{"role": "user", "content": "What is 2+2?"},
]
asyncio.run(engine.distill(payload))
@@ -772,3 +786,11 @@ def test_engine_distill_uses_user_prompt_for_teacher(tinker_engine, mock_trainin
called_prompt = mock_build.call_args[0][0]
assert called_prompt == "What is 2+2?"
assert "<|im_start|>" not in called_prompt
+ # system_prompt is forwarded to the teacher
+ assert mock_build.call_args[1]["system_prompt"] == "You are a pirate."
+
+ # Teacher prompt must use add_generation_prompt=True to include the
+ # assistant turn delimiter before response tokens.
+ tok = mock_training_client.get_tokenizer()
+ tok.apply_chat_template.assert_called_once()
+ assert tok.apply_chat_template.call_args[1]["add_generation_prompt"] is True