diff --git a/claas/api.py b/claas/api.py
index 5284ed9..12f38c4 100644
--- a/claas/api.py
+++ b/claas/api.py
@@ -36,6 +36,7 @@
 from __future__ import annotations
 
 import asyncio
+import dataclasses
 import hashlib
 import logging
 import os
@@ -86,6 +87,7 @@
     ServiceHealth,
     TextCompletionChoice,
     TextCompletionResponse,
+    TrainingConfig,
 )
 from .dashboard import feedback_log as feedback_log_mod, rendering as dashboard_rendering
 from .inference import get_inference_backend, vllm_control
@@ -215,6 +217,28 @@ def _get_inference_backend(request: Request) -> InferenceBackend:
     return request.app.state.inference_backend
 
 
+def _validate_training_config(training: TrainingConfig) -> None:
+    """Validate training config ranges for direct API callers."""
+    errors: list[str] = []
+    if training.learning_rate <= 0:
+        errors.append("learning_rate must be > 0")
+    if not (0.0 <= training.alpha <= 1.0):
+        errors.append("alpha must be within [0, 1]")
+    if not (1.0 <= training.is_clip <= 20.0):
+        errors.append("is_clip must be within [1, 20]")
+    if training.max_grad_norm < 0.0:
+        errors.append("max_grad_norm must be >= 0")
+    if not (0.0 <= training.kl_reg_weight <= 1.0):
+        errors.append("kl_reg_weight must be within [0, 1]")
+    if not (10 <= training.teacher_top_k <= 100):
+        errors.append("teacher_top_k must be within [10, 100]")
+    if errors:
+        raise HTTPException(
+            status_code=422,
+            detail=f"invalid training config: {'; '.join(errors)}",
+        )
+
+
 # ---------------------------------------------------------------------------
 # Inference endpoints
 # ---------------------------------------------------------------------------
@@ -372,12 +396,14 @@ async def feedback(request: FeedbackBatchRequest) -> FeedbackResponse:
     first_request = batch_requests[0]
     lora_id = first_request.lora_id
-    training_ref = first_request.training.model_dump(mode="json")
+    _validate_training_config(first_request.training)
+    training_ref = dataclasses.asdict(first_request.training)
 
     for req in batch_requests[1:]:
+        _validate_training_config(req.training)
         if req.lora_id != lora_id:
            raise HTTPException(status_code=400, detail="all requests must use the same lora_id")
-        if req.training.model_dump(mode="json") != training_ref:
+        if dataclasses.asdict(req.training) != training_ref:
            raise HTTPException(status_code=400, detail="all requests must use the same training config")
 
     # Resolve cache entries before acquiring lock or doing orchestration.
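Note: the new validator can be exercised directly, outside the HTTP layer. A minimal sketch, reusing `_validate_training_config` and `TrainingConfig` from the hunks above (REPL-style usage is illustrative):

```python
# Sketch: triggering the 422 path directly. Both bounds below are out of range,
# so the detail string joins one message per violation with "; ".
from fastapi import HTTPException

from claas.api import _validate_training_config
from claas.core.types import TrainingConfig

try:
    _validate_training_config(TrainingConfig(is_clip=-1.0, teacher_top_k=5))
except HTTPException as exc:
    # 422 "invalid training config: is_clip must be within [1, 20];
    #      teacher_top_k must be within [10, 100]"
    print(exc.status_code, exc.detail)
```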
diff --git a/claas/core/types.py b/claas/core/types.py
index 6a81f8f..0b9e106 100644
--- a/claas/core/types.py
+++ b/claas/core/types.py
@@ -6,6 +6,7 @@
 
 from __future__ import annotations
 
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Literal, TypedDict
 
 from pydantic import BaseModel, ConfigDict, Field
@@ -21,42 +22,16 @@ class ChatMessage(TypedDict):
     content: str
 
 
-class TrainingConfig(BaseModel):
-    """Training configuration for distillation."""
+@dataclass
+class TrainingConfig:
+    """Training hyperparameters (dataclass for Hydra structured-config compatibility)."""
 
-    learning_rate: float = Field(
-        default=3e-5,
-        description="Learning rate for LoRA parameter updates",
-    )
-    alpha: float = Field(
-        default=0.5,
-        ge=0.0,
-        le=1.0,
-        description="GJS interpolation (0.5 = symmetric JSD, 1.0 = reverse KL)",
-    )
-    is_clip: float = Field(
-        default=5.0,
-        ge=1.0,
-        le=20.0,
-        description="Importance sampling ratio clip (exp space)",
-    )
-    max_grad_norm: float = Field(
-        default=1.0,
-        ge=0.0,
-        description="Maximum gradient norm for clipping",
-    )
-    kl_reg_weight: float = Field(
-        default=0.0,
-        ge=0.0,
-        le=1.0,
-        description="Weight for KL regularization to base policy",
-    )
-    teacher_top_k: int = Field(
-        default=100,
-        ge=10,
-        le=100,
-        description="Number of top logprobs to request from teacher",
-    )
+    learning_rate: float = 3e-5
+    alpha: float = 0.5
+    is_clip: float = 5.0
+    max_grad_norm: float = 1.0
+    kl_reg_weight: float = 0.0
+    teacher_top_k: int = 100
 
 
 class SDPOLossInput(BaseModel):
@@ -473,5 +448,3 @@ class TextCompletionResponse(BaseModel):
     model: str
     choices: list[TextCompletionChoice]
     usage: CompletionUsage
-
-
diff --git a/claas/eval/README.md b/claas/eval/README.md
index 242d80a..92438fe 100644
--- a/claas/eval/README.md
+++ b/claas/eval/README.md
@@ -26,13 +26,20 @@
 metrics:                  # metrics to evaluate per step
 num_steps: 20
 batch_size: 4
-steps_per_batch: 1        # gradient updates per batch
+steps_per_batch: 4        # gradient updates per batch
 feedback_repetitions: 1   # times to repeat feedback string
+training:                 # forwarded to /v1/feedback training config
+  learning_rate: 3e-5
+  alpha: 0.5
+  is_clip: 5.0
+  max_grad_norm: 1.0
+  kl_reg_weight: 0.0
+  teacher_top_k: 100
 collapse_steps: [0, 5, 10, 15, 19]  # steps where collapse metric runs
 plots: true               # generate matplotlib plots
 seed: 42
 lora_id_prefix: eval
-output_dir: ./data/evals
+output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ}
 openclaw_url: http://localhost:18789  # OpenClaw gateway (null = use CLaaS API directly)
 ```
 
@@ -48,6 +55,9 @@ uv run python -m claas.eval 'preferences=[concise]' num_steps=10
 
 # Override base model and mode
 uv run python -m claas.eval base_model=Qwen/Qwen3-30B-A3B mode=tinker
 
+# Override training hyperparameters
+uv run python -m claas.eval training.is_clip=7.0 training.learning_rate=1e-4
+
 # Use a custom config directory
 uv run python -m claas.eval --config-dir ./my_configs --config-name my_config
@@ -55,16 +65,14 @@ uv run python -m claas.eval --config-dir ./my_configs --config-name my_config
 ### Programmatic usage
 
 ```python
-from claas.eval.config import build_harness_config
 from claas.eval.runner import run_harness
 from claas.eval.types import EvalConfig
 import asyncio
 
-config = build_harness_config(
-    EvalConfig(
-        preferences=["concise"],
-        num_steps=5,
-    )
+config = EvalConfig(
+    preferences=["concise"],
+    num_steps=5,
+    output_dir="./data/evals/manual-run",  # explicit when bypassing Hydra CLI
 )
 asyncio.run(run_harness(config))
 ```
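Note: with `TrainingConfig` now a plain dataclass, the batch handler compares configs via `dataclasses.asdict` instead of Pydantic's `model_dump(mode="json")`. A small sketch of the equality check `/v1/feedback` performs across a batch:

```python
# Sketch: dataclasses.asdict round-trips the dataclass to a plain dict, so
# value-equal configs compare equal and any field mismatch triggers the 400.
import dataclasses

from claas.core.types import TrainingConfig

ref = dataclasses.asdict(TrainingConfig())
assert ref == dataclasses.asdict(TrainingConfig(learning_rate=3e-5))  # same values
assert ref != dataclasses.asdict(TrainingConfig(is_clip=7.0))         # mismatch -> 400
```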
diff --git a/claas/eval/__main__.py b/claas/eval/__main__.py
index b174efb..a2344d2 100644
--- a/claas/eval/__main__.py
+++ b/claas/eval/__main__.py
@@ -13,7 +13,7 @@
 import hydra
 from omegaconf import OmegaConf
 
-from .config import build_harness_config
+from . import config as _config  # noqa: F401
 from .types import EvalConfig
 
 
@@ -25,8 +25,7 @@ def main(cfg: EvalConfig) -> None:
     if not isinstance(eval_cfg, EvalConfig):
         raise TypeError("Hydra did not produce an EvalConfig instance")
 
-    config = build_harness_config(eval_cfg)
-    asyncio.run(run_harness(config))
+    asyncio.run(run_harness(eval_cfg))
 
 
 if __name__ == "__main__":
diff --git a/claas/eval/config.py b/claas/eval/config.py
index e6424bc..bd912de 100644
--- a/claas/eval/config.py
+++ b/claas/eval/config.py
@@ -1,32 +1,7 @@
-"""Hydra-based configuration for the evaluation harness."""
-
-from __future__ import annotations
-
-import dataclasses
-import os
-import re
-from datetime import datetime, timezone
-from pathlib import Path
+"""Hydra schema registration for the evaluation harness."""
 
 from hydra.core.config_store import ConfigStore
 
-from .types import EvalConfig, HarnessConfig
-
-# Pattern matching the timestamped run-id suffix (e.g. 20260220-012345Z)
-_RUN_ID_RE = re.compile(r"\d{8}-\d{6}Z$")
+from .types import EvalConfig
 
 ConfigStore.instance().store(name="_eval_schema", node=EvalConfig)
-
-
-def build_harness_config(eval_cfg: EvalConfig) -> HarnessConfig:
-    """Post-process EvalConfig → HarnessConfig (no secrets)."""
-    fields = dataclasses.asdict(eval_cfg)
-
-    # Timestamped output subdir (skip if output_dir already ends with a run-id,
-    # which allows resuming an existing run by passing its directory).
-    output_dir = fields["output_dir"]
-    if not _RUN_ID_RE.search(Path(output_dir).name):
-        run_id = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%SZ")
-        fields["output_dir"] = os.path.join(output_dir, run_id)
-
-    return HarnessConfig(**fields)
diff --git a/claas/eval/configs/base.yaml b/claas/eval/configs/base.yaml
index 92323c1..c16f34d 100644
--- a/claas/eval/configs/base.yaml
+++ b/claas/eval/configs/base.yaml
@@ -24,8 +24,15 @@
 num_steps: 20
 batch_size: 4
 steps_per_batch: 4
 feedback_repetitions: 1
+training:
+  learning_rate: 3e-5
+  alpha: 0.5
+  is_clip: 5.0
+  max_grad_norm: 1.0
+  kl_reg_weight: 0.0
+  teacher_top_k: 100
 seed: 42
 lora_id_prefix: eval
-output_dir: ./data/evals
+output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ}
 openclaw_url: http://localhost:18789
diff --git a/claas/eval/runner.py b/claas/eval/runner.py
index c449d0d..ac59041 100644
--- a/claas/eval/runner.py
+++ b/claas/eval/runner.py
@@ -23,10 +23,10 @@
 from .plotting import generate_plots
 from .preferences import PreferenceConfig, get_preference_configs
 from .types import (
+    EvalConfig,
     EvalMetrics,
     ExperimentResult,
     ExperimentSummary,
-    HarnessConfig,
     LocalDistillMetrics,
     MetricContext,
     StepResult,
@@ -38,7 +38,7 @@
 logger = logging.getLogger(__name__)
 
 
-async def _init_lora(config: HarnessConfig, lora_id: str) -> str:
+async def _init_lora(config: EvalConfig, lora_id: str) -> str:
     """Initialize a fresh LoRA adapter via CLaaS API."""
     # LoRA init can exceed two minutes when the remote trainer is cold-starting.
     async with httpx.AsyncClient(base_url=config.claas_url, timeout=300.0) as client:
@@ -51,7 +51,7 @@
 
 
 async def _submit_feedback(
-    config: HarnessConfig,
+    config: EvalConfig,
     lora_id: str,
     samples: list[FeedbackItem],
 ) -> LocalDistillMetrics | TinkerDistillMetrics | None:
@@ -92,7 +92,7 @@
 
 
 async def _generate_response(
-    config: HarnessConfig,
+    config: EvalConfig,
     model: str,
     prompt: str,
     temperature: float = 0,
@@ -127,7 +127,7 @@
 
 
 async def _measure_eval_metrics(
-    config: HarnessConfig,
+    config: EvalConfig,
     pref: PreferenceConfig,
     model_name: str,
     step: int,
@@ -205,7 +205,7 @@
         f.write(json.dumps(data) + "\n")
 
 
-def _write_metadata(output_dir: str, preference: str, config: HarnessConfig, lora_id: str) -> None:
+def _write_metadata(output_dir: str, preference: str, config: EvalConfig, lora_id: str) -> None:
     """Write experiment metadata."""
     pref_dir = os.path.join(output_dir, preference)
     os.makedirs(pref_dir, exist_ok=True)
@@ -273,7 +273,7 @@ def _write_summary(output_dir: str, results: list[ExperimentResult]) -> None:
 
 
 async def run_preference_experiment(
-    config: HarnessConfig,
+    config: EvalConfig,
     pref: PreferenceConfig,
     enabled_metrics: list[Metric] | None = None,
     needs_generation: bool = False,
@@ -357,6 +357,8 @@
     )
 
     # Main loop
+    training_cfg = config.training
+
     for step in range(resume_from, config.num_steps):
         step_start = time.perf_counter()
 
@@ -385,6 +387,7 @@
                     prompt=prompt,
                     response=content,
                     feedback=feedback_str,
+                    training=training_cfg,
                 ))
             except (httpx.HTTPError, KeyError, ValueError) as e:
                 logger.warning(
@@ -471,7 +474,7 @@
     return result
 
 
-async def run_harness(config: HarnessConfig) -> None:
+async def run_harness(config: EvalConfig) -> None:
     """Run the full evaluation harness."""
     logging.basicConfig(
         level=logging.INFO,
diff --git a/claas/eval/types.py b/claas/eval/types.py
index 84d1733..02978dc 100644
--- a/claas/eval/types.py
+++ b/claas/eval/types.py
@@ -7,7 +7,7 @@
 from typing import Any, Optional
 
 from claas.core.config import DEFAULT_SYSTEM_PROMPT
-from claas.core.types import ChatMessage
+from claas.core.types import ChatMessage, TrainingConfig
 
 
 @dataclass
@@ -96,10 +96,7 @@ class EvalConfig:
     batch_size: int = 4
     steps_per_batch: int = 4
     feedback_repetitions: int = 1
-
-
-# HarnessConfig is the post-processed runtime config (still no secrets).
-HarnessConfig = EvalConfig
+    training: TrainingConfig = field(default_factory=TrainingConfig)
 
 
 @dataclass
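Note: with the `HarnessConfig` alias gone, callers build `EvalConfig` directly and nest a `TrainingConfig`. A sketch mirroring the README's "Programmatic usage" snippet (the override values are illustrative, not recommendations):

```python
# Sketch: programmatic construction with nested training hyperparameters.
import asyncio

from claas.core.types import TrainingConfig
from claas.eval.runner import run_harness
from claas.eval.types import EvalConfig

config = EvalConfig(
    preferences=["concise"],
    num_steps=5,
    training=TrainingConfig(is_clip=7.0, learning_rate=1e-4),
    output_dir="./data/evals/manual-run",  # explicit when bypassing Hydra's ${now:} resolver
)
asyncio.run(run_harness(config))
```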
diff --git a/tests/test_api.py b/tests/test_api.py
index 678ad81..cc81f3e 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -272,6 +272,60 @@ async def fake_wait_idle(base_url, api_key, timeout_s):
     assert log_records and log_records[0]["status"] == "error"
 
 
+def test_feedback_rejects_invalid_training_config(monkeypatch):
+    _mock_config(monkeypatch, "modal")
+    client = TestClient(web_app)
+
+    response = client.post(
+        "/v1/feedback",
+        json={
+            "requests": [
+                {
+                    "lora_id": "user/model",
+                    "prompt": "clean prompt",
+                    "response": "unused",
+                    "feedback": "f",
+                    "training": {"is_clip": -1},
+                }
+            ]
+        },
+    )
+
+    assert response.status_code == 422
+    assert "invalid training config" in response.json()["detail"]
+    assert "is_clip" in response.json()["detail"]
+
+
+def test_feedback_rejects_mismatched_training_config(monkeypatch):
+    _mock_config(monkeypatch, "modal")
+    client = TestClient(web_app)
+
+    response = client.post(
+        "/v1/feedback",
+        json={
+            "requests": [
+                {
+                    "lora_id": "user/model",
+                    "prompt": "prompt-1",
+                    "response": "response-1",
+                    "feedback": "f",
+                    "training": {"is_clip": 5.0},
+                },
+                {
+                    "lora_id": "user/model",
+                    "prompt": "prompt-2",
+                    "response": "response-2",
+                    "feedback": "f",
+                    "training": {"is_clip": 7.0},
+                },
+            ]
+        },
+    )
+
+    assert response.status_code == 400
+    assert response.json()["detail"] == "all requests must use the same training config"
+
+
 def test_export_404_when_missing(monkeypatch):
     from claas import api
 
diff --git a/tests/test_eval_config.py b/tests/test_eval_config.py
index 2179bee..930ce8e 100644
--- a/tests/test_eval_config.py
+++ b/tests/test_eval_config.py
@@ -8,12 +8,14 @@
 
 import pytest
 from hydra import compose, initialize_config_dir
+from hydra.errors import ConfigCompositionException
 from omegaconf import OmegaConf
 
-from claas.eval.config import build_harness_config
+from claas.core.types import TrainingConfig
+from claas.eval import config as _eval_config  # noqa: F401
 from claas.eval.metrics import VerifierResult
 from claas.eval.preferences import get_preference_configs
-from claas.eval.types import EvalConfig, HarnessConfig
+from claas.eval.types import EvalConfig
 
 
 def _make_config_dir(yaml_content: str) -> str:
@@ -44,20 +46,33 @@ def _compose_eval_config(
 
 
 def test_hydra_config_defaults() -> None:
-    config = build_harness_config(_compose_eval_config())
-    assert isinstance(config, HarnessConfig)
+    config = _compose_eval_config()
+    assert isinstance(config, EvalConfig)
+    assert isinstance(config.training, TrainingConfig)
     assert config.mode == "tinker"
     assert config.num_steps == 20
     assert config.batch_size == 4
     assert config.seed == 42
     assert config.lora_id_prefix == "eval"
     assert config.plots is True
+    assert config.training.is_clip == 5.0
+    assert config.training.learning_rate == 3e-5
+    assert config.output_dir.startswith("./data/evals/")
 
 
 def test_hydra_config_with_overrides() -> None:
-    config = build_harness_config(_compose_eval_config(["num_steps=5", "seed=99"]))
+    config = _compose_eval_config(
+        ["num_steps=5", "seed=99", "training.is_clip=7.0", "training.learning_rate=1e-4"]
+    )
     assert config.num_steps == 5
     assert config.seed == 99
+    assert config.training.is_clip == 7.0
+    assert config.training.learning_rate == 1e-4
+
+
+def test_training_type_invariance_rejects_scalar() -> None:
+    with pytest.raises(ConfigCompositionException):
+        _compose_eval_config(["training=oops"])
 
 
 def test_unknown_key_rejected() -> None:
@@ -67,24 +82,29 @@ def test_unknown_key_rejected() -> None:
 
 
 def test_timestamped_output_dir() -> None:
-    config = build_harness_config(_compose_eval_config(["output_dir=/tmp/test-evals"]))
-    assert config.output_dir.startswith("/tmp/test-evals/")
-    assert len(config.output_dir) > len("/tmp/test-evals/")
+    config = _compose_eval_config()
+    assert config.output_dir.startswith("./data/evals/")
+    assert len(config.output_dir) > len("./data/evals/")
+
+
+def test_output_dir_override_respected() -> None:
+    config = _compose_eval_config(["output_dir=/tmp/test-evals"])
+    assert config.output_dir == "/tmp/test-evals"
 
 
 def test_collapse_steps_list() -> None:
-    config = build_harness_config(_compose_eval_config(["collapse_steps=[0,5,10]"]))
+    config = _compose_eval_config(["collapse_steps=[0,5,10]"])
     assert config.collapse_steps == [0, 5, 10]
 
 
 def test_steps_per_batch_from_config() -> None:
-    config = build_harness_config(_compose_eval_config(["steps_per_batch=3"]))
+    config = _compose_eval_config(["steps_per_batch=3"])
     assert config.steps_per_batch == 3
 
 
 def test_hydra_config_custom_dir() -> None:
     tmpdir = _make_config_dir("mode: local\nnum_steps: 3\npreferences:\n  - no_emoji\n")
-    config = build_harness_config(_compose_eval_config(config_dir=tmpdir))
+    config = _compose_eval_config(config_dir=tmpdir)
     assert config.mode == "local"
     assert config.num_steps == 3
     assert config.preferences == ["no_emoji"]
diff --git a/tests/test_eval_runner.py b/tests/test_eval_runner.py
index fa04c0c..38cfebf 100644
--- a/tests/test_eval_runner.py
+++ b/tests/test_eval_runner.py
@@ -2,8 +2,9 @@
 
 from __future__ import annotations
 
+from claas.core.types import TrainingConfig
 from claas.eval.types import (
-    HarnessConfig,
+    EvalConfig,
     LocalDistillMetrics,
     StepResult,
     TinkerDistillMetrics,
@@ -104,33 +105,40 @@ def test_step_result_from_dict_no_metrics():
     assert result.sdpo_metrics is None
 
 
-# ── HarnessConfig defaults ───────────────────────────────────────────
+# ── EvalConfig defaults ──────────────────────────────────────────────
 
 
 def test_steps_per_batch_default():
-    """HarnessConfig.steps_per_batch defaults to 4."""
-    config = HarnessConfig()
+    """EvalConfig.steps_per_batch defaults to 4."""
+    config = EvalConfig()
     assert config.steps_per_batch == 4
 
 
 def test_steps_per_batch_custom():
-    """HarnessConfig.steps_per_batch can be set."""
-    config = HarnessConfig(steps_per_batch=3)
+    """EvalConfig.steps_per_batch can be set."""
+    config = EvalConfig(steps_per_batch=3)
     assert config.steps_per_batch == 3
 
 
 def test_feedback_repetitions_default():
-    """HarnessConfig.feedback_repetitions defaults to 1."""
-    config = HarnessConfig()
+    """EvalConfig.feedback_repetitions defaults to 1."""
+    config = EvalConfig()
     assert config.feedback_repetitions == 1
 
 
 def test_feedback_repetitions_custom():
-    """HarnessConfig.feedback_repetitions can be set."""
-    config = HarnessConfig(feedback_repetitions=4)
+    """EvalConfig.feedback_repetitions can be set."""
+    config = EvalConfig(feedback_repetitions=4)
     assert config.feedback_repetitions == 4
 
 
+def test_training_config_default_type():
+    """EvalConfig.training defaults to typed TrainingConfig."""
+    config = EvalConfig()
+    assert isinstance(config.training, TrainingConfig)
+    assert config.training.is_clip == 5.0
+
+
 # ── StepResult sub_step_count ────────────────────────────────────────
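Note: the run-id timestamp formerly appended by `build_harness_config` now comes from Hydra's built-in `${now:...}` resolver in `base.yaml`, and explicit `output_dir` overrides are taken verbatim. A sketch of the composed behavior (the `config_path` is illustrative and assumes the snippet sits next to `claas/eval/configs/`):

```python
# Sketch: composing the eval config resolves ${now:%Y%m%d-%H%M%SZ} lazily
# via Hydra's built-in "now" resolver; importing claas.eval.config first
# registers the _eval_schema node with the ConfigStore.
from hydra import compose, initialize

from claas.eval import config as _eval_config  # noqa: F401

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="base")
    print(cfg.output_dir)  # e.g. ./data/evals/20260220-012345Z

    # Overrides are respected as-is: no run-id suffix is appended.
    cfg = compose(config_name="base", overrides=["output_dir=/tmp/test-evals"])
    print(cfg.output_dir)  # /tmp/test-evals
```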