Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions claas/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from __future__ import annotations

import asyncio
import dataclasses
import hashlib
import logging
import os
Expand Down Expand Up @@ -86,6 +87,7 @@
ServiceHealth,
TextCompletionChoice,
TextCompletionResponse,
TrainingConfig,
)
from .dashboard import feedback_log as feedback_log_mod, rendering as dashboard_rendering
from .inference import get_inference_backend, vllm_control
Expand Down Expand Up @@ -215,6 +217,28 @@ def _get_inference_backend(request: Request) -> InferenceBackend:
return request.app.state.inference_backend


def _validate_training_config(training: TrainingConfig) -> None:
    """Validate training config ranges for direct API callers.

    Mirrors the range constraints previously enforced by the Pydantic model,
    collecting every violation so the caller sees all problems at once.

    Raises:
        HTTPException: 422 with a semicolon-joined list of violated constraints.
    """
    # Each entry pairs a passing condition with its failure message.
    checks: list[tuple[bool, str]] = [
        (training.learning_rate > 0, "learning_rate must be > 0"),
        (0.0 <= training.alpha <= 1.0, "alpha must be within [0, 1]"),
        (1.0 <= training.is_clip <= 20.0, "is_clip must be within [1, 20]"),
        (training.max_grad_norm >= 0.0, "max_grad_norm must be >= 0"),
        (0.0 <= training.kl_reg_weight <= 1.0, "kl_reg_weight must be within [0, 1]"),
        (10 <= training.teacher_top_k <= 100, "teacher_top_k must be within [10, 100]"),
    ]
    failures = [message for passed, message in checks if not passed]
    if failures:
        raise HTTPException(
            status_code=422,
            detail=f"invalid training config: {'; '.join(failures)}",
        )


# ---------------------------------------------------------------------------
# Inference endpoints
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -372,12 +396,14 @@ async def feedback(request: FeedbackBatchRequest) -> FeedbackResponse:

first_request = batch_requests[0]
lora_id = first_request.lora_id
training_ref = first_request.training.model_dump(mode="json")
_validate_training_config(first_request.training)
training_ref = dataclasses.asdict(first_request.training)

for req in batch_requests[1:]:
_validate_training_config(req.training)
if req.lora_id != lora_id:
raise HTTPException(status_code=400, detail="all requests must use the same lora_id")
if req.training.model_dump(mode="json") != training_ref:
if dataclasses.asdict(req.training) != training_ref:
raise HTTPException(status_code=400, detail="all requests must use the same training config")

# Resolve cache entries before acquiring lock or doing orchestration.
Expand Down
47 changes: 10 additions & 37 deletions claas/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, TypedDict

from pydantic import BaseModel, ConfigDict, Field
Expand All @@ -21,42 +22,16 @@ class ChatMessage(TypedDict):
content: str


class TrainingConfig(BaseModel):
"""Training configuration for distillation."""
@dataclass
class TrainingConfig:
"""Training hyperparameters (dataclass for Hydra structured-config compatibility)."""

learning_rate: float = Field(
default=3e-5,
description="Learning rate for LoRA parameter updates",
)
alpha: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="GJS interpolation (0.5 = symmetric JSD, 1.0 = reverse KL)",
)
is_clip: float = Field(
default=5.0,
ge=1.0,
le=20.0,
description="Importance sampling ratio clip (exp space)",
)
max_grad_norm: float = Field(
default=1.0,
ge=0.0,
description="Maximum gradient norm for clipping",
)
kl_reg_weight: float = Field(
default=0.0,
ge=0.0,
le=1.0,
description="Weight for KL regularization to base policy",
)
teacher_top_k: int = Field(
default=100,
ge=10,
le=100,
description="Number of top logprobs to request from teacher",
)
learning_rate: float = 3e-5
alpha: float = 0.5
is_clip: float = 5.0
max_grad_norm: float = 1.0
kl_reg_weight: float = 0.0
teacher_top_k: int = 100


class SDPOLossInput(BaseModel):
Expand Down Expand Up @@ -473,5 +448,3 @@ class TextCompletionResponse(BaseModel):
model: str
choices: list[TextCompletionChoice]
usage: CompletionUsage


24 changes: 16 additions & 8 deletions claas/eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ metrics: # metrics to evaluate per step

num_steps: 20
batch_size: 4
steps_per_batch: 1 # gradient updates per batch
steps_per_batch: 4 # gradient updates per batch
feedback_repetitions: 1 # times to repeat feedback string
training: # forwarded to /v1/feedback training config
learning_rate: 3e-5
alpha: 0.5
is_clip: 5.0
max_grad_norm: 1.0
kl_reg_weight: 0.0
teacher_top_k: 100
collapse_steps: [0, 5, 10, 15, 19] # steps where collapse metric runs
plots: true # generate matplotlib plots
seed: 42
lora_id_prefix: eval
output_dir: ./data/evals
output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ}

openclaw_url: http://localhost:18789 # OpenClaw gateway (null = use CLaaS API directly)
```
Expand All @@ -48,23 +55,24 @@ uv run python -m claas.eval 'preferences=[concise]' num_steps=10
# Override base model and mode
uv run python -m claas.eval base_model=Qwen/Qwen3-30B-A3B mode=tinker

# Override training hyperparameters
uv run python -m claas.eval training.is_clip=7.0 training.learning_rate=1e-4

# Use a custom config directory
uv run python -m claas.eval --config-dir ./my_configs --config-name my_config
```

### Programmatic usage

```python
from claas.eval.config import build_harness_config
from claas.eval.runner import run_harness
from claas.eval.types import EvalConfig
import asyncio

config = build_harness_config(
EvalConfig(
preferences=["concise"],
num_steps=5,
)
config = EvalConfig(
preferences=["concise"],
num_steps=5,
output_dir="./data/evals/manual-run", # explicit when bypassing Hydra CLI
)
asyncio.run(run_harness(config))
```
Expand Down
5 changes: 2 additions & 3 deletions claas/eval/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import hydra
from omegaconf import OmegaConf

from .config import build_harness_config
from . import config as _config # noqa: F401
from .types import EvalConfig


Expand All @@ -25,8 +25,7 @@ def main(cfg: EvalConfig) -> None:
if not isinstance(eval_cfg, EvalConfig):
raise TypeError("Hydra did not produce an EvalConfig instance")

config = build_harness_config(eval_cfg)
asyncio.run(run_harness(config))
asyncio.run(run_harness(eval_cfg))


if __name__ == "__main__":
Expand Down
29 changes: 2 additions & 27 deletions claas/eval/config.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,7 @@
"""Hydra-based configuration for the evaluation harness."""

from __future__ import annotations

import dataclasses
import os
import re
from datetime import datetime, timezone
from pathlib import Path
"""Hydra schema registration for the evaluation harness."""

from hydra.core.config_store import ConfigStore

from .types import EvalConfig, HarnessConfig

# Pattern matching the timestamped run-id suffix (e.g. 20260220-012345Z)
_RUN_ID_RE = re.compile(r"\d{8}-\d{6}Z$")
from .types import EvalConfig

ConfigStore.instance().store(name="_eval_schema", node=EvalConfig)


def build_harness_config(eval_cfg: EvalConfig) -> HarnessConfig:
"""Post-process EvalConfig → HarnessConfig (no secrets)."""
fields = dataclasses.asdict(eval_cfg)

# Timestamped output subdir (skip if output_dir already ends with a run-id,
# which allows resuming an existing run by passing its directory).
output_dir = fields["output_dir"]
if not _RUN_ID_RE.search(Path(output_dir).name):
run_id = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%SZ")
fields["output_dir"] = os.path.join(output_dir, run_id)

return HarnessConfig(**fields)
9 changes: 8 additions & 1 deletion claas/eval/configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,15 @@ num_steps: 20
batch_size: 4
steps_per_batch: 4
feedback_repetitions: 1
training:
learning_rate: 3e-5
alpha: 0.5
is_clip: 5.0
max_grad_norm: 1.0
kl_reg_weight: 0.0
teacher_top_k: 100
seed: 42
lora_id_prefix: eval
output_dir: ./data/evals
output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ}

openclaw_url: http://localhost:18789
19 changes: 11 additions & 8 deletions claas/eval/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@
from .plotting import generate_plots
from .preferences import PreferenceConfig, get_preference_configs
from .types import (
EvalConfig,
EvalMetrics,
ExperimentResult,
ExperimentSummary,
HarnessConfig,
LocalDistillMetrics,
MetricContext,
StepResult,
Expand All @@ -38,7 +38,7 @@
logger = logging.getLogger(__name__)


async def _init_lora(config: HarnessConfig, lora_id: str) -> str:
async def _init_lora(config: EvalConfig, lora_id: str) -> str:
"""Initialize a fresh LoRA adapter via CLaaS API."""
# LoRA init can exceed two minutes when the remote trainer is cold-starting.
async with httpx.AsyncClient(base_url=config.claas_url, timeout=300.0) as client:
Expand All @@ -51,7 +51,7 @@ async def _init_lora(config: HarnessConfig, lora_id: str) -> str:


async def _submit_feedback(
config: HarnessConfig,
config: EvalConfig,
lora_id: str,
samples: list[FeedbackItem],
) -> LocalDistillMetrics | TinkerDistillMetrics | None:
Expand Down Expand Up @@ -92,7 +92,7 @@ async def _submit_feedback(


async def _generate_response(
config: HarnessConfig,
config: EvalConfig,
model: str,
prompt: str,
temperature: float = 0,
Expand Down Expand Up @@ -127,7 +127,7 @@ async def _generate_response(


async def _measure_eval_metrics(
config: HarnessConfig,
config: EvalConfig,
pref: PreferenceConfig,
model_name: str,
step: int,
Expand Down Expand Up @@ -205,7 +205,7 @@ def _append_step_jsonl(output_dir: str, preference: str, step: StepResult) -> No
f.write(json.dumps(data) + "\n")


def _write_metadata(output_dir: str, preference: str, config: HarnessConfig, lora_id: str) -> None:
def _write_metadata(output_dir: str, preference: str, config: EvalConfig, lora_id: str) -> None:
"""Write experiment metadata."""
pref_dir = os.path.join(output_dir, preference)
os.makedirs(pref_dir, exist_ok=True)
Expand Down Expand Up @@ -273,7 +273,7 @@ def _write_summary(output_dir: str, results: list[ExperimentResult]) -> None:


async def run_preference_experiment(
config: HarnessConfig,
config: EvalConfig,
pref: PreferenceConfig,
enabled_metrics: list[Metric] | None = None,
needs_generation: bool = False,
Expand Down Expand Up @@ -357,6 +357,8 @@ async def run_preference_experiment(
)

# Main loop
training_cfg = config.training

for step in range(resume_from, config.num_steps):
step_start = time.perf_counter()

Expand Down Expand Up @@ -385,6 +387,7 @@ async def run_preference_experiment(
prompt=prompt,
response=content,
feedback=feedback_str,
training=training_cfg,
))
except (httpx.HTTPError, KeyError, ValueError) as e:
logger.warning(
Expand Down Expand Up @@ -471,7 +474,7 @@ async def run_preference_experiment(
return result


async def run_harness(config: HarnessConfig) -> None:
async def run_harness(config: EvalConfig) -> None:
"""Run the full evaluation harness."""
logging.basicConfig(
level=logging.INFO,
Expand Down
7 changes: 2 additions & 5 deletions claas/eval/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Optional

from claas.core.config import DEFAULT_SYSTEM_PROMPT
from claas.core.types import ChatMessage
from claas.core.types import ChatMessage, TrainingConfig


@dataclass
Expand Down Expand Up @@ -96,10 +96,7 @@ class EvalConfig:
batch_size: int = 4
steps_per_batch: int = 4
feedback_repetitions: int = 1


# HarnessConfig is the post-processed runtime config (still no secrets).
HarnessConfig = EvalConfig
training: TrainingConfig = field(default_factory=TrainingConfig)


@dataclass
Expand Down
Loading