diff --git a/synth_ai/sdk/optimization/internal/learning/client.py b/synth_ai/sdk/optimization/internal/learning/client.py index 25a5d22c..91280508 100644 --- a/synth_ai/sdk/optimization/internal/learning/client.py +++ b/synth_ai/sdk/optimization/internal/learning/client.py @@ -1,14 +1,37 @@ +import json import warnings -from collections.abc import Callable +from collections.abc import Callable, Mapping from contextlib import suppress from pathlib import Path -from typing import Any, TypedDict +from typing import Any, Optional, TypedDict from synth_ai.core.errors import HTTPError +from synth_ai.core.levers import ScopeKey from synth_ai.core.rust_core.http import RustCoreHttpClient, sleep +from synth_ai.sdk.optimization.models import LeverHandle, SensorFrame from synth_ai.sdk.shared.models import UnsupportedModelError, normalize_model_identifier +def _normalize_scope_payload(scope: list[Any] | None) -> Optional[list[dict[str, Any]]]: + if not scope: + return None + normalized: list[dict[str, Any]] = [] + for item in scope: + if isinstance(item, ScopeKey): + normalized.append(item.to_dict()) + elif isinstance(item, dict): + normalized.append(item) + return normalized or None + + +def _payload_to_dict(value: Any) -> dict[str, Any]: + if isinstance(value, Mapping): + return dict(value) + if hasattr(value, "to_dict"): + return value.to_dict() # type: ignore[attr-defined] + raise ValueError("Payload must be a dict or provide to_dict()") + + class LearningClient: """Client for learning/training jobs. @@ -154,6 +177,118 @@ async def poll_until_terminal( if max_seconds is not None and elapsed >= max_seconds: raise TimeoutError(f"Polling timed out after {elapsed} seconds for job {job_id}") + async def create_or_update_lever( + self, + optimizer_id: str, + lever: dict[str, Any] | LeverHandle, + ) -> LeverHandle: + """Create or update a lever handle for an optimizer. + + See: specifications/tanha/future/sensors_and_levers.txt + """ + payload = _payload_to_dict(lever) + url = f"/api/v1/optimizers/{optimizer_id}/levers" + async with RustCoreHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http: + js = await http.post_json(url, json=payload) + if not isinstance(js, dict): + raise HTTPError( + status=500, + url=url, + message="invalid_lever_response", + body_snippet=str(js)[:200], + ) + return LeverHandle.from_dict(js) + + async def resolve_lever( + self, + optimizer_id: str, + lever_id: str, + *, + scope: list[Any] | None = None, + snapshot: bool = True, + ) -> LeverHandle | None: + """Resolve a lever snapshot for the optimizer's scope. + + See: specifications/tanha/future/sensors_and_levers.txt + """ + params: dict[str, str] = {"lever_id": lever_id} + if snapshot: + params["snapshot"] = "true" + scope_payload = _normalize_scope_payload(scope) + if scope_payload: + params["scope"] = json.dumps(scope_payload) + url = f"/api/v1/optimizers/{optimizer_id}/levers" + async with RustCoreHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http: + js = await http.get(url, params=params if params else None) + if isinstance(js, dict): + return LeverHandle.from_dict(js) + return None + + async def emit_sensor_frame( + self, + optimizer_id: str, + frame: dict[str, Any] | SensorFrame, + ) -> SensorFrame: + """Emit a sensor frame payload for the optimizer. + + See: specifications/tanha/future/sensors_and_levers.txt + """ + payload = _payload_to_dict(frame) + url = f"/api/v1/optimizers/{optimizer_id}/sensors" + async with RustCoreHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http: + js = await http.post_json(url, json=payload) + if not isinstance(js, dict): + raise HTTPError( + status=500, + url=url, + message="invalid_sensor_frame_response", + body_snippet=str(js)[:200], + ) + return SensorFrame.from_dict(js) + + async def list_sensor_frames( + self, + optimizer_id: str, + *, + scope: list[Any] | None = None, + limit: int | None = None, + ) -> list[SensorFrame]: + """List sensor frames emitted for the optimizer scope. + + See: specifications/tanha/future/sensors_and_levers.txt + """ + params: dict[str, str] = {} + scope_payload = _normalize_scope_payload(scope) + if scope_payload: + params["scope"] = json.dumps(scope_payload) + if limit is not None: + params["limit"] = str(limit) + url = f"/api/v1/optimizers/{optimizer_id}/sensors" + async with RustCoreHttpClient(self._base_url, self._api_key, timeout=self._timeout) as http: + js = await http.get(url, params=params if params else None) + payloads: list[Any] = [] + if isinstance(js, list): + payloads = js + elif isinstance(js, dict): + candidates = ( + js.get("items") + or js.get("sensor_frames") + or js.get("frames") + or js.get("data") + or [] + ) + if isinstance(candidates, list): + payloads = candidates + frames: list[SensorFrame] = [] + for entry in payloads: + if not isinstance(entry, dict): + continue + try: + frames.append(SensorFrame.from_dict(entry)) + except ValueError: + continue + return frames + # --- Optional diagnostics --- async def pricing_preflight( self, *, job_type: str, gpu_type: str, estimated_seconds: float, container_count: int diff --git a/synth_ai/sdk/optimization/models.py b/synth_ai/sdk/optimization/models.py index 3faecf16..1fb71ef1 100644 --- a/synth_ai/sdk/optimization/models.py +++ b/synth_ai/sdk/optimization/models.py @@ -4,8 +4,8 @@ from enum import Enum from typing import Any, Dict, Iterable, Optional -from synth_ai.core.levers import MiproLeverSummary -from synth_ai.core.sensors import SensorFrameSummary +from synth_ai.core.levers import LeverKind, ScopeKey, MiproLeverSummary +from synth_ai.core.sensors import Sensor as CoreSensor, SensorFrameSummary def _first_present(data: Dict[str, Any], keys: Iterable[str]) -> Optional[Any]: @@ -106,7 +106,123 @@ def _extract_system_prompt( if result: return result - return None +@dataclass(slots=True) +class LeverHandle: + """SDK representation of a lever handle resolved by optimizer APIs. + + See: specifications/tanha/future/sensors_and_levers.txt + """ + + lever_id: str + kind: LeverKind + version: int + scope: list[ScopeKey] = field(default_factory=list) + metadata: dict[str, Any] = field(default_factory=dict) + + @classmethod + def from_dict(cls, raw: Dict[str, Any]) -> "LeverHandle": + if not isinstance(raw, dict): + raise ValueError("LeverHandle data must be an object") + lever_id = raw.get("lever_id") or raw.get("id") or "" + kind_raw = raw.get("kind") + try: + kind = LeverKind(str(kind_raw)) if kind_raw is not None else LeverKind.CUSTOM + except ValueError: + kind = LeverKind.CUSTOM + version_raw = raw.get("version") or raw.get("lever_version") + version = 0 + if version_raw is not None: + try: + version = int(version_raw) + except (TypeError, ValueError): + version = 0 + scope_raw = raw.get("scope") or [] + scope: list[ScopeKey] = [] + if isinstance(scope_raw, list): + for item in scope_raw: + if isinstance(item, dict): + scope.append(ScopeKey.from_dict(item)) + metadata = raw.get("metadata") if isinstance(raw.get("metadata"), dict) else {} + return cls( + lever_id=str(lever_id), + kind=kind, + version=version, + scope=scope, + metadata=metadata, + ) + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "lever_id": self.lever_id, + "kind": self.kind.value, + "version": self.version, + "scope": [item.to_dict() for item in self.scope], + } + if self.metadata: + payload["metadata"] = self.metadata + return payload + + +@dataclass(slots=True) +class SensorFrame: + """SDK representation of a sensor frame emitted by optimizer endpoints. + + See: specifications/tanha/future/sensors_and_levers.txt + """ + + scope: list[ScopeKey] = field(default_factory=list) + sensors: list[CoreSensor] = field(default_factory=list) + lever_versions: Dict[str, int] = field(default_factory=dict) + trace_ids: list[str] = field(default_factory=list) + frame_id: Optional[str] = None + metadata: dict[str, Any] = field(default_factory=dict) + created_at: Optional[str] = None + + @classmethod + def from_dict(cls, raw: Dict[str, Any]) -> "SensorFrame": + if not isinstance(raw, dict): + raise ValueError("SensorFrame data must be an object") + scope_raw = raw.get("scope") or [] + scope: list[ScopeKey] = [] + if isinstance(scope_raw, list): + for item in scope_raw: + if isinstance(item, dict): + scope.append(ScopeKey.from_dict(item)) + sensors_raw = raw.get("sensors") or [] + sensors: list[CoreSensor] = [] + if isinstance(sensors_raw, list): + for item in sensors_raw: + if isinstance(item, dict): + sensors.append(CoreSensor.from_dict(item)) + lever_versions = _parse_lever_versions(raw.get("lever_versions")) + trace_ids_raw = raw.get("trace_ids") or [] + trace_ids = [str(x) for x in trace_ids_raw if isinstance(x, (str, int))] if isinstance(trace_ids_raw, list) else [] + frame_id = raw.get("frame_id") if isinstance(raw.get("frame_id"), str) else None + metadata = raw.get("metadata") if isinstance(raw.get("metadata"), dict) else {} + created_at = raw.get("created_at") if isinstance(raw.get("created_at"), str) else None + return cls( + scope=scope, + sensors=sensors, + lever_versions=lever_versions, + trace_ids=trace_ids, + frame_id=frame_id, + metadata=metadata, + created_at=created_at, + ) + + def to_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "scope": [item.to_dict() for item in self.scope], + "sensors": [sensor.to_dict() for sensor in self.sensors], + "lever_versions": self.lever_versions, + "trace_ids": self.trace_ids, + "metadata": self.metadata, + } + if self.frame_id is not None: + payload["frame_id"] = self.frame_id + if self.created_at is not None: + payload["created_at"] = self.created_at + return payload class PolicyJobStatus(str, Enum): diff --git a/tests/optimization/test_learning_client_levers_sensors.py b/tests/optimization/test_learning_client_levers_sensors.py new file mode 100644 index 00000000..f3e309c4 --- /dev/null +++ b/tests/optimization/test_learning_client_levers_sensors.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import asyncio +import json +from typing import Any + +from synth_ai.sdk.optimization.internal.learning import client as learning_client +from synth_ai.sdk.optimization.internal.learning.client import LearningClient + + +def test_create_or_update_lever_posts_payload(monkeypatch) -> None: + captured: dict[str, Any] = {} + + class _FakeHttpClient: + def __init__(self, *_args: Any, **_kwargs: Any) -> None: + return + + async def __aenter__(self) -> "_FakeHttpClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + _ = (exc_type, exc, tb) + + async def post_json(self, path: str, json: dict[str, Any]) -> Any: + captured["path"] = path + captured["json"] = json + return { + "lever_id": json.get("lever_id", "lever_1"), + "kind": json.get("kind", "prompt"), + "version": json.get("version", 1), + "scope": json.get("scope", []), + } + + monkeypatch.setattr(learning_client, "RustCoreHttpClient", _FakeHttpClient) + client = LearningClient(base_url="https://example.com", api_key="key") + payload = { + "lever_id": "lever_123", + "kind": "prompt", + "version": 5, + "scope": [{"kind": "org", "id": "org_1"}], + } + result = asyncio.run(client.create_or_update_lever("opt_1", payload)) + assert captured["path"] == "/api/v1/optimizers/opt_1/levers" + assert captured["json"]["lever_id"] == "lever_123" + assert result.lever_id == "lever_123" + assert result.version == 5 + + +def test_resolve_lever_builds_query(monkeypatch) -> None: + captured: dict[str, Any] = {} + + class _FakeHttpClient: + def __init__(self, *_args: Any, **_kwargs: Any) -> None: + return + + async def __aenter__(self) -> "_FakeHttpClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + _ = (exc_type, exc, tb) + + async def get(self, path: str, params: dict[str, Any] | None = None) -> Any: + captured["path"] = path + captured["params"] = params + return { + "lever_id": "lever_456", + "kind": "constraint", + "version": 2, + "scope": json.loads(params["scope"]) if params and params.get("scope") else [], + } + + monkeypatch.setattr(learning_client, "RustCoreHttpClient", _FakeHttpClient) + client = LearningClient(base_url="https://example.com", api_key="key") + scope_param = [{"kind": "job", "id": "job_1"}] + result = asyncio.run(client.resolve_lever("opt_2", "lever_456", scope=scope_param)) + assert captured["path"] == "/api/v1/optimizers/opt_2/levers" + assert captured["params"]["lever_id"] == "lever_456" + assert json.loads(captured["params"]["scope"]) == scope_param + assert result is not None + assert result.kind.value == "constraint" + + +def test_emit_sensor_frame_returns_frame(monkeypatch) -> None: + payload_frame = { + "scope": [{"kind": "job", "id": "job_2"}], + "sensors": [ + { + "sensor_id": "reward.main", + "kind": "reward", + "scope": [{"kind": "job", "id": "job_2"}], + "value": {"reward": 0.7}, + } + ], + "lever_versions": {"lever_abc": 3}, + } + + class _FakeHttpClient: + def __init__(self, *_args: Any, **_kwargs: Any) -> None: + return + + async def __aenter__(self) -> "_FakeHttpClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + _ = (exc_type, exc, tb) + + async def post_json(self, path: str, json: dict[str, Any]) -> Any: + assert path == "/api/v1/optimizers/opt_3/sensors" + return { + **json, + "frame_id": "frame_42", + } + + monkeypatch.setattr(learning_client, "RustCoreHttpClient", _FakeHttpClient) + client = LearningClient(base_url="https://example.com", api_key="key") + frame = asyncio.run(client.emit_sensor_frame("opt_3", payload_frame)) + assert frame.frame_id == "frame_42" + assert frame.lever_versions["lever_abc"] == 3 + assert frame.sensors[0].sensor_id == "reward.main" + + +def test_list_sensor_frames_parses_list(monkeypatch) -> None: + responses = [ + { + "scope": [{"kind": "job", "id": "job_3"}], + "sensors": [], + "lever_versions": {}, + "frame_id": "frame_list", + } + ] + + class _FakeHttpClient: + def __init__(self, *_args: Any, **_kwargs: Any) -> None: + return + + async def __aenter__(self) -> "_FakeHttpClient": + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + _ = (exc_type, exc, tb) + + async def get(self, path: str, params: dict[str, Any] | None = None) -> Any: + assert path == "/api/v1/optimizers/opt_4/sensors" + assert params is not None and params.get("limit") == "10" + return responses + + monkeypatch.setattr(learning_client, "RustCoreHttpClient", _FakeHttpClient) + client = LearningClient(base_url="https://example.com", api_key="key") + frames = asyncio.run(client.list_sensor_frames("opt_4", limit=10)) + assert len(frames) == 1 + assert frames[0].frame_id == "frame_list"