diff --git a/README.md b/README.md index 76ae148..1f1bc44 100644 --- a/README.md +++ b/README.md @@ -167,12 +167,13 @@ The eval dashboard at `/v1/eval` displays results from running the eval harness ## References -1. Hübotter et al. (2026). "Reinforcement Learning via Self-Distillation." arXiv:2601.20802 -2. SDPO Reference Implementation: https://github.com/lasgroup/SDPO -3. Modal GPU Memory Snapshots: https://modal.com/blog/gpu-mem-snapshots -4. vLLM: https://github.com/vllm-project/vllm -5. PEFT/LoRA: https://github.com/huggingface/peft -6. Tinker SDPO training reference (continualcode): https://github.com/sdan/continualcode +1. Kleine Buening et al. (2026). "Aligning Language Models from User Interactions." [Paper](https://self-distillation.github.io/user_interactions.pdf) · [Code](https://github.com/lasgroup/user_interactions) · [Hindsight template](https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py) +2. Hübotter et al. (2026). "Reinforcement Learning via Self-Distillation." arXiv:2601.20802 +3. SDPO Reference Implementation: https://github.com/lasgroup/SDPO +4. Modal GPU Memory Snapshots: https://modal.com/blog/gpu-mem-snapshots +5. vLLM: https://github.com/vllm-project/vllm +6. PEFT/LoRA: https://github.com/huggingface/peft +7. Tinker SDPO training reference (continualcode): https://github.com/sdan/continualcode ## License diff --git a/claas/api.py b/claas/api.py index c67aeab..5284ed9 100644 --- a/claas/api.py +++ b/claas/api.py @@ -232,6 +232,8 @@ async def chat_completions( {"role": m.role, "content": coerce_content(m.content)} for m in req.messages ] + system_parts = [m["content"] for m in messages if m["role"] == "system"] + system_prompt = "\n".join(system_parts) if system_parts else "" result = await backend.chat_completion( messages=messages, model=req.model or "default", @@ -259,6 +261,7 @@ async def chat_completions( response_token_ids=result.response_token_ids, prompt_token_ids=result.prompt_token_ids, response_logprobs=result.response_logprobs, + system_prompt=system_prompt, ), ) @@ -394,6 +397,15 @@ async def feedback(request: FeedbackBatchRequest) -> FeedbackResponse: status_code=422, detail=f"Cached completion has no logprobs (content_hash={content_hash[:16]}…)", ) + if not entry.system_prompt: + raise HTTPException( + status_code=422, + detail=( + f"Cached completion has no system prompt (content_hash={content_hash[:16]}…). " + "The chat completion request must include a system message so the teacher " + "can score under the same context as the student." + ), + ) batch_samples.append( DistillBatchItem( prompt=entry.prompt, @@ -403,6 +415,7 @@ async def feedback(request: FeedbackBatchRequest) -> FeedbackResponse: prompt_token_ids=entry.prompt_token_ids, response_token_ids=entry.response_token_ids, user_prompt=req.prompt, + system_prompt=entry.system_prompt, ) ) diff --git a/claas/core/types.py b/claas/core/types.py index bca5151..419265f 100644 --- a/claas/core/types.py +++ b/claas/core/types.py @@ -120,6 +120,12 @@ class DistillBatchItem(BaseModel): "not a nested chat template." ), ) + system_prompt: str = Field( + description=( + "System prompt from the chat completion request. " + "Passed to the teacher so it scores under the same context as the student." + ), + ) class DistillBatchRequestPayload(BaseModel): diff --git a/claas/dashboard/rendering.py b/claas/dashboard/rendering.py index 3ac26ea..536c717 100644 --- a/claas/dashboard/rendering.py +++ b/claas/dashboard/rendering.py @@ -77,7 +77,15 @@ def feedback_dashboard_rows(records: list[FeedbackLogRecord]) -> str: # -- Expandable detail row -- sample_sections: list[str] = [] + _raw_texts = metrics_payload.get("teacher_scored_texts") if metrics_payload else None + teacher_scored_texts = list(_raw_texts) if isinstance(_raw_texts, list) else [] for item_index, sample in enumerate(record.batch_samples): + teacher_section = "" + if item_index < len(teacher_scored_texts): + teacher_section = ( + "

Teacher Scored Text

" + "
{teacher_text}
" + ).format(teacher_text=html.escape(str(teacher_scored_texts[item_index]))) sample_sections.append( """ @@ -86,6 +94,7 @@ def feedback_dashboard_rows(records: list[FeedbackLogRecord]) -> str:

Prompt

{prompt}

Response

{response}

Feedback

{feedback}
+ {teacher_section} """.format( @@ -96,6 +105,7 @@ def feedback_dashboard_rows(records: list[FeedbackLogRecord]) -> str: prompt=html.escape(sample.prompt), response=html.escape(sample.response), feedback=html.escape(sample.feedback), + teacher_section=teacher_section, ) ) diff --git a/claas/eval/runner.py b/claas/eval/runner.py index 3912e10..c449d0d 100644 --- a/claas/eval/runner.py +++ b/claas/eval/runner.py @@ -40,7 +40,8 @@ async def _init_lora(config: HarnessConfig, lora_id: str) -> str: """Initialize a fresh LoRA adapter via CLaaS API.""" - async with httpx.AsyncClient(base_url=config.claas_url, timeout=120.0) as client: + # LoRA init can exceed two minutes when the remote trainer is cold-starting. + async with httpx.AsyncClient(base_url=config.claas_url, timeout=300.0) as client: resp = await client.post( "/v1/lora/init", json={"lora_id": lora_id, "base_model": config.base_model}, diff --git a/claas/inference/cache.py b/claas/inference/cache.py index f7a5058..2f0f454 100644 --- a/claas/inference/cache.py +++ b/claas/inference/cache.py @@ -12,7 +12,7 @@ class CompletionCacheEntry: """A single cached completion with prompt, response, token IDs, and logprobs.""" - __slots__ = ("prompt", "response", "response_token_ids", "prompt_token_ids", "response_logprobs", "created_at") + __slots__ = ("prompt", "response", "response_token_ids", "prompt_token_ids", "response_logprobs", "system_prompt", "created_at") def __init__( self, @@ -21,12 +21,14 @@ def __init__( response_token_ids: list[int], prompt_token_ids: list[int], response_logprobs: list[float] | None, + system_prompt: str, ) -> None: self.prompt = prompt self.response = response self.response_token_ids = response_token_ids self.prompt_token_ids = prompt_token_ids self.response_logprobs = response_logprobs + self.system_prompt = system_prompt self.created_at = time.monotonic() def is_expired(self) -> bool: diff --git a/claas/training/distillation.py b/claas/training/distillation.py index b64c964..a46ee89 100644 --- a/claas/training/distillation.py +++ b/claas/training/distillation.py @@ -167,7 +167,9 @@ def _build_self_teacher_topk( feedback: str, response_ids: "torch.Tensor", top_k: int, - ) -> tuple["torch.Tensor", "torch.Tensor"]: + *, + system_prompt: str, + ) -> tuple["torch.Tensor", "torch.Tensor", str]: """Build top-k teacher logits from the frozen base model. Args: @@ -177,18 +179,18 @@ def _build_self_teacher_topk( top_k: Number of logits to retain per token. Returns: - Pair of top-k logprobs and indices for each response token. + Triple of (top-k logprobs, top-k indices, teacher_scored_text). Reference: - Huebotter et al. (2026), "Reinforcement Learning via Self-Distillation" - (arXiv:2601.20802), Section 3. + Kleine Buening et al. (2026), "Aligning Language Models from User Interactions" + https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py """ - messages = build_teacher_messages(prompt, feedback) + messages = build_teacher_messages(prompt, feedback, system_prompt=system_prompt) template_messages = teacher_messages_to_chat_template(messages) teacher_prompt_ids_raw = self.tokenizer.apply_chat_template( template_messages, - add_generation_prompt=False, + add_generation_prompt=True, return_tensors="pt", tokenize=True, ) @@ -209,9 +211,10 @@ def _build_self_teacher_topk( top_logprobs, top_indices = torch.topk(log_probs[0, :response_token_count], k=k, dim=-1) del teacher_output, teacher_logits, log_probs + teacher_scored_text = self.tokenizer.decode(teacher_full_ids[0].tolist(), skip_special_tokens=False) del teacher_full_ids, teacher_prompt_ids torch.cuda.empty_cache() - return top_logprobs, top_indices + return top_logprobs, top_indices, teacher_scored_text def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: """Run one SDPO distillation step. @@ -256,6 +259,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: batch_kl_reg: list[float] = [] batch_mean_is_ratio: list[float] = [] batch_clip_fraction: list[float] = [] + batch_teacher_scored_texts: list[str] = [] tokens_processed = 0 for sample in payload.samples: @@ -302,11 +306,12 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: elif old_student_logprobs.shape[1] < response_token_count: raise ValueError("response_logprobs length must match response token length") - teacher_logprobs, teacher_indices = self._build_self_teacher_topk( + teacher_logprobs, teacher_indices, teacher_scored_text = self._build_self_teacher_topk( sample.user_prompt, sample.feedback, response_ids, config.teacher_top_k, + system_prompt=sample.system_prompt, ) if teacher_logprobs.shape[0] != response_token_count: @@ -330,6 +335,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: batch_kl_reg.append(loss_dict["kl_reg"]) batch_mean_is_ratio.append(loss_dict["mean_is_ratio"]) batch_clip_fraction.append(loss_dict["clip_fraction"]) + batch_teacher_scored_texts.append(teacher_scored_text) del full_ids, prompt_ids, response_ids, response_mask del student_logits, base_logprobs, old_student_logprobs @@ -377,6 +383,7 @@ def distill(self, payload: DistillBatchRequestPayload) -> DistillResponse: "grad_norm": grad_norm_value, "tokens_processed": tokens_processed, "batch_size": len(payload.samples), + "teacher_scored_texts": batch_teacher_scored_texts, }, } ) diff --git a/claas/training/engine/tinker/engine.py b/claas/training/engine/tinker/engine.py index af45737..c5db847 100644 --- a/claas/training/engine/tinker/engine.py +++ b/claas/training/engine/tinker/engine.py @@ -252,6 +252,7 @@ async def distill( "lr": lr, "loss_fn": "importance_sampling", "timestamp": datetime.now(timezone.utc).isoformat(), + "teacher_scored_texts": [m["teacher_scored_text"] for m in sample_metrics], } if hasattr(fwd_bwd, "metrics") and fwd_bwd.metrics: @@ -265,7 +266,7 @@ async def _build_sample_datum( tokenizer: Any, teacher_sampling: Any, kl_coef: float, -) -> tuple[T.Datum, dict[str, float]]: +) -> tuple[T.Datum, dict[str, float | str]]: """Process a single sample into a Datum and per-sample metrics. This is the per-sample logic extracted from ``distill()`` so that @@ -288,11 +289,13 @@ async def _build_sample_datum( # ── Build teacher prompt (matching local worker: build_teacher_messages) ── teacher_prompt_source = sample.user_prompt - teacher_messages = build_teacher_messages(teacher_prompt_source, sample.feedback) + teacher_messages = build_teacher_messages( + teacher_prompt_source, sample.feedback, system_prompt=sample.system_prompt + ) template_messages = teacher_messages_to_chat_template(teacher_messages) teacher_prompt_text = tokenizer.apply_chat_template( template_messages, - add_generation_prompt=False, + add_generation_prompt=True, tokenize=False, ) teacher_prompt_tokens: list[int] = tokenizer.encode( @@ -302,6 +305,7 @@ async def _build_sample_datum( teacher_full_tokens = teacher_prompt_tokens + response_tokens teacher_prompt_len = len(teacher_prompt_tokens) teacher_full = T.ModelInput.from_ints(teacher_full_tokens) + teacher_scored_text = tokenizer.decode(teacher_full_tokens, skip_special_tokens=False) # ── Compute teacher logprobs (base model = self-distillation) ── teacher_logprobs_full = await teacher_sampling.compute_logprobs_async(teacher_full) @@ -346,7 +350,7 @@ async def _build_sample_datum( adv_abs_mean = sum(abs(a) for a in advantages) / max(len(advantages), 1) kl_mean = sum(raw_kl_deltas) / max(len(raw_kl_deltas), 1) - metrics = { + metrics: dict[str, float | str] = { "completion_len": completion_len, "effective_kl_coef": effective_kl_coef, "kl_gain": gain, @@ -354,6 +358,7 @@ async def _build_sample_datum( "adv_abs_mean": adv_abs_mean, "kl_mean": kl_mean, "adv_abs_mean_raw": adv_abs_mean_raw, + "teacher_scored_text": teacher_scored_text, } return datum, metrics diff --git a/claas/training/teacher_helpers.py b/claas/training/teacher_helpers.py index 85a93ea..5f0d8ab 100644 --- a/claas/training/teacher_helpers.py +++ b/claas/training/teacher_helpers.py @@ -1,43 +1,51 @@ """Pure helper functions for teacher prompt building. These functions have no Modal dependency and can be imported anywhere. + +Reference: + Kleine Buening et al. (2026), "Aligning Language Models from User Interactions" + https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py """ from __future__ import annotations -from claas.core.config import DEFAULT_SYSTEM_PROMPT from claas.core.types import ChatMessage def build_teacher_messages( prompt: str, feedback: str | None = None, - system_prompt: str | None = None, + *, + system_prompt: str, ) -> list[ChatMessage]: - """Build chat messages for teacher prompt (veRL-compatible template). + """Build chat messages for the hindsight policy (SDPO teacher). + + Appends a hindsight context block to the user prompt so the frozen base + model can score the student's response tokens with awareness of the user's + follow-up. Template follows the online SDPO trainer from: + https://github.com/lasgroup/user_interactions/blob/main/online_sdpo_trainer.py - Template matches veRL's reprompt_template structure: - {prompt}{feedback}\\n\\nCorrectly solve the original question. + The system_prompt must match the one the student saw so the teacher scores + under the same context. Args: prompt: The original user prompt - feedback: Optional feedback about the response quality - system_prompt: Optional system prompt + feedback: Optional user follow-up / feedback about the response + system_prompt: System prompt (must match the student's) Returns: List of message dicts with 'role' and 'content' keys """ - if system_prompt is None: - system_prompt = DEFAULT_SYSTEM_PROMPT messages: list[ChatMessage] = [{"role": "system", "content": system_prompt}] - # Build user content with veRL-style template if feedback: - feedback_section = ( - "\n\nThe following is positive or negative feedback from your earlier attempt:" - f"\n\n{feedback}\n" + block = ( + "\n\n=== HINDSIGHT CONTEXT ===\n" + "[The following is a future user message. " + "Use this to guide your answer to the user prompt.]\n" + f"{feedback.strip()}" ) - user_content = f"{prompt}{feedback_section}\n\nCorrectly solve the original question.\n" + user_content = f"{prompt}{block}" else: user_content = prompt diff --git a/docker/scripts/init-stack.py b/docker/scripts/init-stack.py index 35fe431..cf654b8 100755 --- a/docker/scripts/init-stack.py +++ b/docker/scripts/init-stack.py @@ -260,6 +260,7 @@ def write_openclaw_config() -> None: "port": 18789, "mode": "local", "bind": "loopback", + "controlUi": {"enabled": False}, "http": { "endpoints": { "chatCompletions": {"enabled": True}, diff --git a/tests/integration/run_tinker_stack_integration.sh b/tests/integration/run_tinker_stack_integration.sh index 54cca83..3e36791 100755 --- a/tests/integration/run_tinker_stack_integration.sh +++ b/tests/integration/run_tinker_stack_integration.sh @@ -27,7 +27,13 @@ compose() { docker compose -f "$COMPOSE_FILE" --profile tinker "$@" } +dump_logs() { + echo "==> Docker container logs (last 200 lines each) ..." + compose logs --tail=200 || true +} + cleanup() { + dump_logs echo "==> Tearing down tinker stack ..." compose down -v || true } diff --git a/tests/integration/test_local_engine_integration.py b/tests/integration/test_local_engine_integration.py index aefbd4a..06e945c 100644 --- a/tests/integration/test_local_engine_integration.py +++ b/tests/integration/test_local_engine_integration.py @@ -107,6 +107,7 @@ def test_local_engine_integration_paths(monkeypatch, tmp_path): prompt_token_ids=[1, 2], response_token_ids=[3], user_prompt="prompt", + system_prompt="You are a helpful assistant.", ) ], ) @@ -158,6 +159,7 @@ def test_local_engine_cleanup_failure_propagates(monkeypatch): prompt_token_ids=[1, 2], response_token_ids=[3], user_prompt="prompt", + system_prompt="You are a helpful assistant.", ) ], ) diff --git a/tests/integration/test_tinker_stack_integration.py b/tests/integration/test_tinker_stack_integration.py index aa18f46..926aa6e 100644 --- a/tests/integration/test_tinker_stack_integration.py +++ b/tests/integration/test_tinker_stack_integration.py @@ -112,7 +112,11 @@ def test_full_lifecycle(self, tinker_stack: TinkerStack) -> None: assert lora_id in list_resp.json()["loras"] # 3. Inference through CLaaS API - chat_messages = [{"role": "user", "content": user_prompt}] + system_prompt = "You are a helpful assistant." + chat_messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] chat_payload = { "model": tinker_stack.model, "messages": chat_messages, @@ -136,7 +140,7 @@ def test_full_lifecycle(self, tinker_stack: TinkerStack) -> None: # 4. Distill via feedback endpoint (teacher_mode=self) # The API resolves prompt, response, and logprobs from # the completion cache by hashing the response text. - teacher_messages = build_teacher_messages(user_prompt, feedback_text) + teacher_messages = build_teacher_messages(user_prompt, feedback_text, system_prompt=system_prompt) logger.info( "Teacher messages (built by engine for self-distillation):\n%s", json.dumps(teacher_messages, indent=2), diff --git a/tests/test_api.py b/tests/test_api.py index 3ca00c5..678ad81 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -108,7 +108,15 @@ def __init__(self): self.remote = _RemoteCallFailure() -def _seed_cache(visible_response: str, *, logprobs: list[float] | None = None) -> None: +_DEFAULT_TEST_SYSTEM_PROMPT = "You are a helpful assistant." + + +def _seed_cache( + visible_response: str, + *, + logprobs: list[float] | None = None, + system_prompt: str = _DEFAULT_TEST_SYSTEM_PROMPT, +) -> None: """Pre-populate the completion cache keyed by normalized SHA-256 of visible_response.""" content_hash = hashlib.sha256(normalize_for_hash(visible_response).encode("utf-8")).hexdigest() completion_cache.put( @@ -119,6 +127,7 @@ def _seed_cache(visible_response: str, *, logprobs: list[float] | None = None) - response_token_ids=[10, 20], prompt_token_ids=[1, 2, 3], response_logprobs=logprobs if logprobs is not None else [-0.1, -0.2], + system_prompt=system_prompt, ), ) @@ -455,6 +464,83 @@ async def fake_run_distill(payload): assert sample["response"] == "cached-response" +def test_feedback_system_prompt_flows_from_cache(monkeypatch, tmp_path): + """system_prompt is resolved from the completion cache into DistillBatchItem.""" + from claas import api + + _mock_config(monkeypatch, "tinker") + _seed_cache("sys prompt response", system_prompt="You are a pirate assistant.") + captured_payload = {} + + class _Engine: + async def lora_exists(self, _lora_id): + return LoraExistsPayload(exists=True) + + async def fake_run_distill(payload): + captured_payload["samples"] = [s.model_dump() for s in payload.samples] + return DistillResponse(lora_id="user/model", metadata={}) + + monkeypatch.setattr(api, "_get_training_engine", lambda: _Engine()) + monkeypatch.setattr(api, "_run_distill", fake_run_distill) + monkeypatch.setattr(feedback_log_mod, "write_feedback_log", lambda _r, _d: str(tmp_path / "log.json")) + + client = TestClient(web_app) + response = client.post( + "/v1/feedback", + json={ + "requests": [ + { + "lora_id": "user/model", + "prompt": "clean prompt", + "response": "sys prompt response", + "feedback": "good", + "training": {}, + } + ], + "orchestration": {"sleep_before": False, "wake_after": False}, + }, + ) + + assert response.status_code == 200 + sample = captured_payload["samples"][0] + assert sample["system_prompt"] == "You are a pirate assistant." + + +def test_feedback_missing_system_prompt_returns_422(monkeypatch, tmp_path): + """Cache entry with empty system_prompt returns 422.""" + from claas import api + + _mock_config(monkeypatch, "tinker") + _seed_cache("no sys prompt response", system_prompt="") + + class _Engine: + async def lora_exists(self, _lora_id): + return LoraExistsPayload(exists=True) + + monkeypatch.setattr(api, "_get_training_engine", lambda: _Engine()) + monkeypatch.setattr(feedback_log_mod, "write_feedback_log", lambda _r, _d: str(tmp_path / "log.json")) + + client = TestClient(web_app) + response = client.post( + "/v1/feedback", + json={ + "requests": [ + { + "lora_id": "user/model", + "prompt": "clean prompt", + "response": "no sys prompt response", + "feedback": "good", + "training": {}, + } + ], + "orchestration": {"sleep_before": False, "wake_after": False}, + }, + ) + + assert response.status_code == 422 + assert "no system prompt" in response.json()["detail"] + + def test_feedback_cache_miss_returns_404(monkeypatch, tmp_path): """Missing cache entry returns 404 before any orchestration.""" from claas import api @@ -505,6 +591,7 @@ def test_feedback_missing_logprobs_returns_422(monkeypatch, tmp_path): response_token_ids=[10], prompt_token_ids=[1], response_logprobs=None, + system_prompt="You are a helpful assistant.", ), ) @@ -834,7 +921,8 @@ def test_dashboard_renders_latest_records(monkeypatch, tmp_path): "response_logprobs": [-0.1], "prompt_token_ids": [1], "response_token_ids": [2], - "user_prompt": "prompt-a" + "user_prompt": "prompt-a", + "system_prompt": "You are a helpful assistant." } ], "vllm": { @@ -917,7 +1005,8 @@ def test_dashboard_truncates_prompt_preview(monkeypatch, tmp_path): "response_logprobs": [-0.1], "prompt_token_ids": [1], "response_token_ids": [2], - "user_prompt": "{long_prompt}" + "user_prompt": "{long_prompt}", + "system_prompt": "You are a helpful assistant." }} ], "vllm": {{ @@ -985,7 +1074,8 @@ def test_dashboard_renders_one_row_per_batch_item(monkeypatch, tmp_path): "response_logprobs": [-0.1], "prompt_token_ids": [1], "response_token_ids": [2], - "user_prompt": "prompt-d1" + "user_prompt": "prompt-d1", + "system_prompt": "You are a helpful assistant." }, { "prompt": "prompt-d2", @@ -994,7 +1084,8 @@ def test_dashboard_renders_one_row_per_batch_item(monkeypatch, tmp_path): "response_logprobs": [-0.2], "prompt_token_ids": [1], "response_token_ids": [2], - "user_prompt": "prompt-d2" + "user_prompt": "prompt-d2", + "system_prompt": "You are a helpful assistant." } ], "vllm": { @@ -1034,6 +1125,132 @@ def test_dashboard_renders_one_row_per_batch_item(monkeypatch, tmp_path): assert "2 samples" in response.text +def test_dashboard_renders_teacher_scored_text(monkeypatch, tmp_path): + """Dashboard shows Teacher Scored Text when present in distill metadata.""" + _mock_config(monkeypatch, "local", feedback_log_dir=str(tmp_path)) + (tmp_path / "20240101T000010-t.json").write_text( + """ +{ + "request_id": "t", + "timestamp_utc": "2024-01-01T00:00:10Z", + "status": "ok", + "phase": "done", + "lora_id": "user/model", + "requests": [ + { + "lora_id": "user/model", + "prompt": "prompt-t", + "response": "response-t", + "feedback": "feedback-t" + } + ], + "batch_samples": [ + { + "prompt": "prompt-t", + "response": "response-t", + "feedback": "feedback-t", + "response_logprobs": [-0.1], + "prompt_token_ids": [1], + "response_token_ids": [2], + "user_prompt": "prompt-t", + "system_prompt": "You are a helpful assistant." + } + ], + "vllm": { + "slept": false, + "woke": false + }, + "timing_ms": { + "sleep": 0, + "distill": 2, + "save": 0, + "wake": 0, + "logprobs": 0, + "total": 2 + }, + "distill_result": { + "lora_id": "user/model", + "metadata": { + "loss": 0.1, + "teacher_scored_texts": ["<|im_start|>system\\nYou are helpful<|im_end|>\\n<|im_start|>user\\nWhat is 2+2?<|im_end|>\\n<|im_start|>assistant\\nFour"] + } + }, + "error": null +} +""".strip(), + encoding="utf-8", + ) + + client = TestClient(web_app) + response = client.get("/v1/dashboard") + + assert response.status_code == 200 + assert "Teacher Scored Text" in response.text + assert "What is 2+2?" in response.text + + +def test_dashboard_omits_teacher_text_when_absent(monkeypatch, tmp_path): + """Dashboard gracefully handles old logs without teacher_scored_texts.""" + _mock_config(monkeypatch, "local", feedback_log_dir=str(tmp_path)) + (tmp_path / "20240101T000011-u.json").write_text( + """ +{ + "request_id": "u", + "timestamp_utc": "2024-01-01T00:00:11Z", + "status": "ok", + "phase": "done", + "lora_id": "user/model", + "requests": [ + { + "lora_id": "user/model", + "prompt": "prompt-u", + "response": "response-u", + "feedback": "feedback-u" + } + ], + "batch_samples": [ + { + "prompt": "prompt-u", + "response": "response-u", + "feedback": "feedback-u", + "response_logprobs": [-0.1], + "prompt_token_ids": [1], + "response_token_ids": [2], + "user_prompt": "prompt-u", + "system_prompt": "You are a helpful assistant." + } + ], + "vllm": { + "slept": false, + "woke": false + }, + "timing_ms": { + "sleep": 0, + "distill": 2, + "save": 0, + "wake": 0, + "logprobs": 0, + "total": 2 + }, + "distill_result": { + "lora_id": "user/model", + "metadata": { + "loss": 0.1 + } + }, + "error": null +} +""".strip(), + encoding="utf-8", + ) + + client = TestClient(web_app) + response = client.get("/v1/dashboard") + + assert response.status_code == 200 + assert "Teacher Scored Text" not in response.text + + def test_feedback_recent_route_is_removed(): client = TestClient(web_app) response = client.get("/v1/feedback/recent") diff --git a/tests/test_local_training_engine.py b/tests/test_local_training_engine.py index 0a2d764..3d72027 100644 --- a/tests/test_local_training_engine.py +++ b/tests/test_local_training_engine.py @@ -54,6 +54,7 @@ def test_local_engine_distill_propagates_cleanup_error(monkeypatch): prompt_token_ids=[1, 2], response_token_ids=[3], user_prompt="p", + system_prompt="You are a helpful assistant.", ) ], ) diff --git a/tests/test_teacher_helpers.py b/tests/test_teacher_helpers.py index 295b2b9..1aa9223 100644 --- a/tests/test_teacher_helpers.py +++ b/tests/test_teacher_helpers.py @@ -2,6 +2,8 @@ from __future__ import annotations +import pytest + from claas.core.types import ChatMessage from claas.training.teacher_helpers import ( build_teacher_messages, @@ -11,30 +13,32 @@ class TestBuildTeacherMessages: def test_with_feedback(self): - msgs = build_teacher_messages("What is 2+2?", "Too verbose") + msgs = build_teacher_messages("What is 2+2?", "Too verbose", system_prompt="You are helpful.") assert len(msgs) == 2 assert msgs[0]["role"] == "system" + assert msgs[0]["content"] == "You are helpful." assert msgs[1]["role"] == "user" assert "What is 2+2?" in msgs[1]["content"] assert "Too verbose" in msgs[1]["content"] - assert "Correctly solve the original question" in msgs[1]["content"] + assert "=== HINDSIGHT CONTEXT ===" in msgs[1]["content"] + assert "Use this to guide your answer to the user prompt" in msgs[1]["content"] def test_without_feedback(self): - msgs = build_teacher_messages("What is 2+2?") + msgs = build_teacher_messages("What is 2+2?", system_prompt="You are helpful.") assert len(msgs) == 2 assert msgs[1]["content"] == "What is 2+2?" def test_without_feedback_explicit_none(self): - msgs = build_teacher_messages("Hello", None) + msgs = build_teacher_messages("Hello", None, system_prompt="You are helpful.") assert msgs[1]["content"] == "Hello" - def test_custom_system_prompt(self): - msgs = build_teacher_messages("x", system_prompt="Be concise.") - assert msgs[0]["content"] == "Be concise." - - def test_default_system_prompt(self): - msgs = build_teacher_messages("x") - assert "helpful assistant" in msgs[0]["content"] + @pytest.mark.parametrize("system_prompt", [ + "Be concise.", + "You are a pirate assistant who speaks in pirate lingo.", + ]) + def test_system_prompt_matches_input(self, system_prompt: str): + msgs = build_teacher_messages("x", system_prompt=system_prompt) + assert msgs[0]["content"] == system_prompt class TestTeacherMessagesToChatTemplate: @@ -50,5 +54,3 @@ def test_converts_typed_messages(self): ] # Returns new list (not the same object) assert result is not msgs - - diff --git a/tests/test_tinker_engine.py b/tests/test_tinker_engine.py index b0818dc..7dd35ca 100644 --- a/tests/test_tinker_engine.py +++ b/tests/test_tinker_engine.py @@ -84,6 +84,7 @@ def mock_tokenizer(): tok = MagicMock() tok.encode.side_effect = lambda text, **kw: list(range(len(text))) tok.apply_chat_template.return_value = "chat-template-output" + tok.decode.side_effect = lambda ids, **kw: f"decoded:{ids}" return tok @@ -417,6 +418,7 @@ def test_engine_distill_full_flow(tinker_engine, mock_training_client): prompt_token_ids=[0, 1, 2, 3, 4], response_token_ids=[0, 1, 2, 3, 4], user_prompt="Hello", + system_prompt="You are a helpful assistant.", ) ], ) @@ -436,6 +438,8 @@ def test_engine_distill_full_flow(tinker_engine, mock_training_client): assert "adv_mean" in result.metadata assert result.metadata["loss_fn"] == "importance_sampling" assert result.metadata["tinker_fwd_metrics"] == {"loss": 0.42} + assert "teacher_scored_texts" in result.metadata + assert len(result.metadata["teacher_scored_texts"]) == 1 # Verify the training client was restored from the init checkpoint mock_service.create_training_client_from_state_async.assert_called_once_with( @@ -493,6 +497,7 @@ def test_engine_distill_uses_provided_response_logprobs(tinker_engine, mock_trai prompt_token_ids=[0, 1, 2, 3, 4], response_token_ids=[0, 1, 2, 3, 4], user_prompt="Hello", + system_prompt="You are a helpful assistant.", ) ], ) @@ -538,6 +543,7 @@ def test_engine_distill_batch_multiple_samples(tinker_engine, mock_training_clie prompt_token_ids=[0, 1, 2, 3, 4], response_token_ids=[0, 1, 2, 3, 4], user_prompt="Hello", + system_prompt="You are a helpful assistant.", ), # Sample 2: 5 tokens DistillBatchItem( @@ -548,6 +554,7 @@ def test_engine_distill_batch_multiple_samples(tinker_engine, mock_training_clie prompt_token_ids=[0, 1], response_token_ids=[0, 1, 2, 3, 4], user_prompt="Hi", + system_prompt="You are a helpful assistant.", ), # Sample 3: 3 tokens DistillBatchItem( @@ -558,6 +565,7 @@ def test_engine_distill_batch_multiple_samples(tinker_engine, mock_training_clie prompt_token_ids=[0, 1, 2], response_token_ids=[0, 1, 2], user_prompt="Hey", + system_prompt="You are a helpful assistant.", ), ] @@ -579,6 +587,10 @@ def test_engine_distill_batch_multiple_samples(tinker_engine, mock_training_clie # 5 + 5 + 3 = 13 assert result.metadata["completion_len"] == 13 + # Teacher scored texts: one per sample + assert "teacher_scored_texts" in result.metadata + assert len(result.metadata["teacher_scored_texts"]) == 3 + # Averaged metrics are present assert "adv_mean" in result.metadata assert "adv_abs_mean" in result.metadata @@ -713,6 +725,7 @@ def test_engine_distill_uses_response_token_ids(tinker_engine, mock_training_cli prompt_token_ids=prompt_ids, response_token_ids=response_ids, user_prompt="What is 2+2?", + system_prompt="You are a helpful assistant.", ) ], ) @@ -757,13 +770,14 @@ def test_engine_distill_uses_user_prompt_for_teacher(tinker_engine, mock_trainin prompt_token_ids=[10, 20, 30], response_token_ids=[40, 50, 60, 70], user_prompt="What is 2+2?", + system_prompt="You are a pirate.", ) ], ) with patch("claas.training.engine.tinker.engine.build_teacher_messages") as mock_build: mock_build.return_value = [ - {"role": "system", "content": "You are helpful"}, + {"role": "system", "content": "You are a pirate."}, {"role": "user", "content": "What is 2+2?"}, ] asyncio.run(engine.distill(payload)) @@ -772,3 +786,11 @@ def test_engine_distill_uses_user_prompt_for_teacher(tinker_engine, mock_trainin called_prompt = mock_build.call_args[0][0] assert called_prompt == "What is 2+2?" assert "<|im_start|>" not in called_prompt + # system_prompt is forwarded to the teacher + assert mock_build.call_args[1]["system_prompt"] == "You are a pirate." + + # Teacher prompt must use add_generation_prompt=True to include the + # assistant turn delimiter before response tokens. + tok = mock_training_client.get_tokenizer() + tok.apply_chat_template.assert_called_once() + assert tok.apply_chat_template.call_args[1]["add_generation_prompt"] is True