From ab30a2d92608918d69a9bceca51e13b56234560a Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 19 Feb 2026 15:50:33 -0700 Subject: [PATCH 01/27] plans for model facade overhaul --- .../343/model-facade-overhaul-plan-step-1.md | 1221 +++++++++++++++++ ...del-facade-overhaul-plan-step-2-bedrock.md | 191 +++ 2 files changed, 1412 insertions(+) create mode 100644 plans/343/model-facade-overhaul-plan-step-1.md create mode 100644 plans/343/model-facade-overhaul-plan-step-2-bedrock.md diff --git a/plans/343/model-facade-overhaul-plan-step-1.md b/plans/343/model-facade-overhaul-plan-step-1.md new file mode 100644 index 000000000..b3523f048 --- /dev/null +++ b/plans/343/model-facade-overhaul-plan-step-1.md @@ -0,0 +1,1221 @@ +--- +date: 2026-02-19 +authors: + - nmulepati +--- + +# Model Facade Overhaul Plan: Step 1 (Non-Bedrock) + +This document proposes a concrete migration plan to replace LiteLLM in Data Designer while keeping the public behavior of `ModelFacade` stable. + +The short version: + +1. Introduce a provider-agnostic client interface under `engine/models/clients/`. +2. Keep `ModelFacade` API and call sites unchanged. +3. Implement adapters for `openai`-compatible APIs first, then Anthropic. +4. Migrate error handling and retries into native code. +5. Remove LiteLLM only after contract-test parity is proven. + +## Reviewer Snapshot + +Reviewers should validate three things first: + +1. `ModelFacade` public behavior does not regress (API shape, MCP loop behavior, usage accounting, user-facing errors). +2. Provider-specific concerns are isolated inside adapters (`openai_compatible`, `anthropic`) behind canonical request/response types. +3. Rollout is reversible with feature flag and bridge adapter until parity is proven. + +## Architecture Diagram + +### 1. Structural view (boundaries and ownership) + +```text +Callers + - Column generators + - ModelRegistry health checks + | + v ++---------------------------------------------------------------+ +| ModelFacade (public surface; unchanged) | +| - generate/agenerate loops | +| - MCP tool loop + correction/restart | +| - usage aggregation + user-facing error context | ++------------------------------+--------------------------------+ + | + v ++---------------------------------------------------------------+ +| Model Client Layer (new) | +| | +| +----------------------+ +-------------------------------+ | +| | Client Factory |-->| Adapter selected by | | +| | - provider_type | | provider_type | | +| | - auth parsing | +---------------+--------------+ | +| +----------+-----------+ | | +| | | | +| v v | +| +----------------------+ +--------------------------+ | +| | Throttle Manager |<---->| Retry Engine | | +| | - global cap key | | - jittered backoff | | +| | - domain key | | - retry classifier | | +| +----------+-----------+ +------------+-------------+ | +| | | | +| +---------------+---------------+ | +| v | +| +-------------------------+ | +| | Adapter implementation | | +| | - OpenAI compatible | | +| | - Anthropic | | +| | - LiteLLM bridge (temp) | | +| +------------+------------+ | ++-------------------------------|------------------------------+ + v + Provider HTTP APIs +``` + +### 2. Runtime sequence (sync/async happy path) + +```text +1) Caller -> ModelFacade.generate/agenerate(...) 
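   - consolidate_kwargs merges inference params + provider extras (before request conversion)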
+2) ModelFacade builds canonical request (types.py) +3) ModelFacade -> ModelClient (adapter chosen by factory) +4) Throttle acquire using: + - global key: (provider_name, model_identifier) + - domain key: (provider_name, model_identifier, throttle_domain) +5) Adapter request with resolved auth context +6) Retry engine executes outbound call (httpx) +7) Provider returns response +8) Adapter normalizes response -> canonical response +9) Throttle release_success + additive recovery +10) ModelFacade applies parser/MCP logic, updates usage stats +11) Return parsed output + trace +``` + +### 3. Runtime sequence (429/throttling path) + +```text +1) Provider returns 429 / throttling error +2) Retry engine classifies RATE_LIMIT and extracts Retry-After (if present) +3) Throttle release_rate_limited: + - multiplicative decrease on domain current_limit + - set blocked_until cooldown +4) Retry re-enters throttle acquire before next attempt +5) On recovery, additive increase restores capacity up to effective max +``` + +## Concrete Implementation Plan (Reviewer-Oriented) + +### File-level change map + +New files (Step 1): + +1. `packages/data-designer-engine/src/data_designer/engine/models/clients/base.py` +2. `packages/data-designer-engine/src/data_designer/engine/models/clients/types.py` +3. `packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py` +4. `packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py` +5. `packages/data-designer-engine/src/data_designer/engine/models/clients/throttle.py` +6. `packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py` +7. `packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py` +8. `packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py` +9. `packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py` + +Updated files (Step 1): + +1. `packages/data-designer-engine/src/data_designer/engine/models/facade.py` +2. `packages/data-designer-engine/src/data_designer/engine/models/errors.py` +3. `packages/data-designer-engine/src/data_designer/engine/models/factory.py` +4. `packages/data-designer-config/src/data_designer/config/models.py` (auth schema extension) +5. `packages/data-designer/src/data_designer/cli/forms/provider_builder.py` (provider-specific auth input) +6. `packages/data-designer-config/src/data_designer/lazy_heavy_imports.py` (remove `litellm` after cutover) + +### PR slicing (recommended) + +1. PR-1: canonical types/interfaces/errors + bridge adapter + no behavior change. +2. PR-2: `ModelFacade` switched to `ModelClient` + parity tests passing on bridge. +3. PR-3: OpenAI-compatible adapter + retry + throttle + auth integration. +4. PR-4: Anthropic adapter + auth integration + capability gating. +5. PR-5: CLI/config schema updates + docs + migration guards. +6. PR-6: Cutover flag default flip + LiteLLM removal for Step 1 scope. + +### Reviewer checklist per PR + +1. Are external method signatures unchanged on `ModelFacade`? +2. Are error classes unchanged at user-facing boundaries? +3. Are sync and async paths symmetric in behavior? +4. Does adaptive throttling honor global cap and domain key rules? +5. Is any secret material exposed in logs or reprs? +6. Is rollback possible via feature flag in the same PR? + +## Why This Plan + +Current usage is concentrated and replaceable: + +1. `packages/data-designer-engine/src/data_designer/engine/models/facade.py` +2. 
`packages/data-designer-engine/src/data_designer/engine/models/errors.py` +3. `packages/data-designer-engine/src/data_designer/engine/models/litellm_overrides.py` +4. `packages/data-designer-engine/src/data_designer/engine/models/factory.py` + +That makes this a good candidate for a strangler migration: preserve the outer behavior, replace internals incrementally. + +## Current Responsibilities To Preserve + +`ModelFacade` currently provides these behaviors and they must remain stable: + +1. Sync and async methods: + - `completion` / `acompletion` + - `generate` / `agenerate` + - `generate_text_embeddings` / `agenerate_text_embeddings` + - `generate_image` / `agenerate_image` +2. Prompt/message conversion and multimodal context handling. +3. MCP tool-calling loop behavior, including tool-turn limits and refusal flow. +4. Usage tracking (`token_usage`, `request_usage`, `image_usage`, `tool_usage`). +5. Exception normalization into `DataDesignerError` subclasses. +6. Provider-level `extra_body` and `extra_headers` merge semantics. + +## Target Architecture + +### 1. New model client layer + +Add a new package: + +`packages/data-designer-engine/src/data_designer/engine/models/clients/` + +Suggested files: + +1. `base.py` - Protocols / interfaces +2. `types.py` - Canonical request/response objects +3. `errors.py` - Provider-agnostic transport/provider exceptions +4. `retry.py` - Backoff policy and retry decision logic +5. `factory.py` - Adapter selection by `provider_type` +6. `adapters/openai_compatible.py` +7. `adapters/anthropic.py` +8. `adapters/litellm_bridge.py` (temporary bridge for migration safety) + +### 2. Keep `ModelFacade` as orchestrator + +`ModelFacade` should continue to orchestrate: + +1. Parser correction/restart loops +2. MCP tool loop +3. Usage aggregation +4. High-level convenience methods + +`ModelFacade` should stop depending directly on LiteLLM response classes. + +### 3. Transport stack + +Use shared transport components: + +1. `httpx.Client` / `httpx.AsyncClient` for HTTP adapters +2. Shared retry module to preserve current exponential backoff and jitter behavior + +## Canonical Types (Adapter Contract) + +Define provider-agnostic types so `ModelFacade` can consume one shape regardless of provider. 
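For orientation before the definitions, a short usage sketch: this is how `ModelFacade` is expected to drive any adapter through the canonical types. The model id, message, and `client` variable are illustrative only; the dataclasses come from the block that follows.

```python
from data_designer.engine.models.clients.types import ChatCompletionRequest

# `client` is assumed to be any ModelClient implementation (interfaces below)
request = ChatCompletionRequest(
    model="gpt-5",
    messages=[{"role": "user", "content": "Say hello."}],
    temperature=0.7,
)
response = client.completion(request)
print(response.message.content, response.usage)
```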
+ +```python +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class Usage: + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None + + +@dataclass +class ImagePayload: + # canonical output to upper layers is base64 without data URI prefix + b64_data: str + mime_type: str | None = None + + +@dataclass +class ToolCall: + id: str + name: str + arguments_json: str + + +@dataclass +class AssistantMessage: + content: str | None = None + reasoning_content: str | None = None + tool_calls: list[ToolCall] = field(default_factory=list) + images: list[ImagePayload] = field(default_factory=list) + + +@dataclass +class ChatCompletionRequest: + model: str + messages: list[dict[str, Any]] + tools: list[dict[str, Any]] | None = None + temperature: float | None = None + top_p: float | None = None + max_tokens: int | None = None + timeout: float | None = None + extra_body: dict[str, Any] | None = None + extra_headers: dict[str, str] | None = None + metadata: dict[str, Any] | None = None + + +@dataclass +class ChatCompletionResponse: + message: AssistantMessage + usage: Usage | None = None + raw: Any | None = None + + +@dataclass +class EmbeddingRequest: + model: str + inputs: list[str] + encoding_format: str | None = None + dimensions: int | None = None + timeout: float | None = None + extra_body: dict[str, Any] | None = None + extra_headers: dict[str, str] | None = None + + +@dataclass +class EmbeddingResponse: + vectors: list[list[float]] + usage: Usage | None = None + raw: Any | None = None + + +@dataclass +class ImageGenerationRequest: + model: str + prompt: str + messages: list[dict[str, Any]] | None = None + n: int | None = None + timeout: float | None = None + extra_body: dict[str, Any] | None = None + extra_headers: dict[str, str] | None = None + + +@dataclass +class ImageGenerationResponse: + images: list[ImagePayload] + usage: Usage | None = None + raw: Any | None = None +``` + +Notes: + +1. `raw` exists for diagnostics/logging only. +2. Canonical image output is always base64 payload. +3. Tool calls are normalized to `id/name/arguments_json`. + +## Adapter Interfaces + +Use explicit interfaces so capabilities are clear. + +```python +from __future__ import annotations + +from typing import Protocol + + +class ChatCompletionClient(Protocol): + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: ... + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: ... + + +class EmbeddingClient(Protocol): + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: ... + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: ... + + +class ImageGenerationClient(Protocol): + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: ... + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: ... + + +class ModelClient(ChatCompletionClient, EmbeddingClient, ImageGenerationClient, Protocol): + provider_name: str + def supports_chat_completion(self) -> bool: ... + def supports_embeddings(self) -> bool: ... + def supports_image_generation(self) -> bool: ... +``` + +Capability checks are important because not all providers/models support all operations. + +## Provider Adapter Shapes + +### OpenAI-compatible adapter + +File: `clients/adapters/openai_compatible.py` + +Scope: + +1. OpenAI REST-compatible endpoints +2. 
NVIDIA Integrate endpoints configured as `provider_type="openai"` +3. OpenRouter and similar gateways with compatible request/response shape + +Implementation expectations: + +1. Chat completion: + - `POST /chat/completions` + - map canonical `messages`, `tools`, and generation params directly +2. Embeddings: + - `POST /embeddings` +3. Image generation: + - Primary: `POST /images/generations` + - Fallback mode: chat-completion image extraction for autoregressive models +4. Parse usage from provider response if present. +5. Normalize tool calls and reasoning fields. +6. Normalize image outputs from either `b64_json`, data URI, or URL download. + +### Anthropic adapter + +File: `clients/adapters/anthropic.py` + +Scope: + +1. Anthropic Messages API and tool use +2. Streaming can be deferred for phase 1 + +Implementation expectations: + +1. Chat completion: + - map system + messages into Anthropic message schema + - map tool schemas and tool-use response blocks to canonical `ToolCall` +2. Reasoning/thinking content: + - map into `reasoning_content` if provider returns a separate channel +3. Unsupported capability handling: + - if embeddings or images are not available/configured, raise canonical unsupported error + +Bedrock adapter planning is intentionally deferred to Step 2: + +1. `plans/343/model-facade-overhaul-plan-step-2-bedrock.md` + +## Authentication and Credential Schema + +Auth should be explicit and provider-specific. Today `ModelProvider` has one `api_key` field, but native adapters need different credential shapes. + +### Auth design goals + +1. Strongly typed auth config per provider. +2. Backward compatibility for existing `api_key` users. +3. Secret values resolved only at runtime via `SecretResolver`. +4. No secret material persisted in logs, traces, or exceptions. + +### Proposed config model evolution + +Keep current fields for compatibility: + +1. `provider_type` +2. `endpoint` +3. `api_key` + +Add optional provider-specific `auth` object: + +```yaml +model_providers: + - name: openai-prod + provider_type: openai + endpoint: https://api.openai.com/v1 + auth: + mode: api_key + api_key: OPENAI_API_KEY + organization: org_abc123 + project: proj_abc123 + + - name: anthropic-prod + provider_type: anthropic + endpoint: https://api.anthropic.com + auth: + mode: api_key + api_key: ANTHROPIC_API_KEY + anthropic_version: "2023-06-01" +``` + +Back-compat rule: + +1. If `auth` is absent and `api_key` is present: + - for `openai` and `anthropic`, treat as `auth.mode=api_key` + +### Proposed typed auth schema + +```python +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel + + +class OpenAIAuth(BaseModel): + mode: Literal["api_key"] = "api_key" + api_key: str + organization: str | None = None + project: str | None = None + + +class AnthropicAuth(BaseModel): + mode: Literal["api_key"] = "api_key" + api_key: str + anthropic_version: str = "2023-06-01" +``` + +### Adapter auth behavior by provider + +#### OpenAI-compatible + +Headers: + +1. `Authorization: Bearer ` +2. Optional `OpenAI-Organization: ` +3. Optional `OpenAI-Project: ` + +Credentials: + +1. Resolve `api_key` through `SecretResolver` +2. Cache in-memory per client instance +3. Never log header values + +#### Anthropic + +Headers: + +1. `x-api-key: ` +2. `anthropic-version: ` +3. `content-type: application/json` + +Credentials: + +1. Resolve `api_key` through `SecretResolver` +2. 
`anthropic_version` can be defaulted in config model + +### Auth resolution flow + +At model client creation time: + +1. Read `ModelProvider.provider_type` +2. Parse/validate `auth` object for that provider type +3. Resolve secret references with `SecretResolver` +4. Build adapter-specific auth context +5. Instantiate adapter client with immutable auth context + +### Auth error normalization + +Map provider auth failures into canonical errors: + +1. OpenAI-compatible `401/403` -> `ProviderError(kind=AUTHENTICATION | PERMISSION_DENIED)` +2. Anthropic `401/403` -> same mapping + +Then map canonical provider errors to existing Data Designer user-facing errors: + +1. `ModelAuthenticationError` +2. `ModelPermissionDeniedError` + +### Secret handling and logging rules + +1. Never log resolved secret values. +2. Redact auth headers if request logging is enabled. +3. Redact any accidental credential-like substrings in exception messages. +4. Avoid storing secrets in dataclass `repr` by using custom `__repr__` or redaction wrappers. + +### Migration plan for auth schema + +#### Phase A + +1. Add optional `auth` field to `ModelProvider`. +2. Keep `api_key` as fallback. +3. Implement adapter builders that accept both forms. + +#### Phase B + +1. Update CLI provider flow to collect provider-specific auth fields. +2. Add validation messages tailored to provider type. + +#### Phase C + +1. Deprecate top-level `api_key` once migration is complete. +2. Keep a compatibility shim for one release cycle. + +## Concrete Adapter Skeletons + +These are intentionally close to implementation shape. + +### Shared adapter base + +```python +from __future__ import annotations + +from typing import Any + +import httpx + +from data_designer.engine.models.clients.errors import ProviderError +from data_designer.engine.models.clients.retry import RetryPolicy, run_with_retries +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, +) + + +class HTTPAdapterBase: + def __init__( + self, + *, + provider_name: str, + endpoint: str, + api_key: str | None, + default_headers: dict[str, str] | None = None, + retry_policy: RetryPolicy | None = None, + timeout_s: float = 60.0, + ) -> None: + self.provider_name = provider_name + self.endpoint = endpoint.rstrip("/") + self.api_key = api_key + self.default_headers = default_headers or {} + self.retry_policy = retry_policy or RetryPolicy.default() + self.timeout_s = timeout_s + self._client = httpx.Client(timeout=self.timeout_s) + self._aclient = httpx.AsyncClient(timeout=self.timeout_s) + + def _headers(self, extra_headers: dict[str, str] | None = None) -> dict[str, str]: + headers = dict(self.default_headers) + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + if extra_headers: + headers.update(extra_headers) + return headers + + def close(self) -> None: + self._client.close() + + async def aclose(self) -> None: + await self._aclient.aclose() +``` + +### OpenAI-compatible adapter skeleton + +```python +class OpenAICompatibleClient(HTTPAdapterBase, ModelClient): + def supports_chat_completion(self) -> bool: + return True + + def supports_embeddings(self) -> bool: + return True + + def supports_image_generation(self) -> bool: + return True + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + payload = { + "model": request.model, + "messages": request.messages, + "tools": request.tools, + 
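            # sampling params may be None here; None entries are stripped just
            # below so provider-side defaults apply when the caller sets nothing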
"temperature": request.temperature, + "top_p": request.top_p, + "max_tokens": request.max_tokens, + } + payload = {k: v for k, v in payload.items() if v is not None} + if request.extra_body: + payload.update(request.extra_body) + + response_json = run_with_retries( + fn=lambda: self._post_json("/chat/completions", payload, request.extra_headers), + policy=self.retry_policy, + ) + return parse_openai_chat_response(response_json) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + ... + + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + ... + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + ... + + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + ... + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + ... +``` + +### Anthropic adapter skeleton + +```python +class AnthropicClient(HTTPAdapterBase, ModelClient): + def supports_chat_completion(self) -> bool: + return True + + def supports_embeddings(self) -> bool: + return False + + def supports_image_generation(self) -> bool: + return False + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + payload = anthropic_payload_from_canonical(request) + response_json = run_with_retries( + fn=lambda: self._post_json("/v1/messages", payload, anthropic_headers(request.extra_headers)), + policy=self.retry_policy, + ) + return parse_anthropic_chat_response(response_json) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + ... + + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + raise ProviderError.unsupported_capability(provider_name=self.provider_name, operation="embeddings") + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + raise ProviderError.unsupported_capability(provider_name=self.provider_name, operation="embeddings") + + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + raise ProviderError.unsupported_capability(provider_name=self.provider_name, operation="image-generation") + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + raise ProviderError.unsupported_capability(provider_name=self.provider_name, operation="image-generation") +``` + +## Request and Response Mapping Details + +### Canonical -> OpenAI-compatible chat payload + +| Canonical field | OpenAI payload field | +|---|---| +| `model` | `model` | +| `messages` | `messages` | +| `tools` | `tools` | +| `temperature` | `temperature` | +| `top_p` | `top_p` | +| `max_tokens` | `max_tokens` | +| `extra_body` | merged into payload | +| `extra_headers` | request headers | + +OpenAI-compatible response parsing: + +1. `choices[0].message.content` -> canonical `message.content` +2. `choices[0].message.tool_calls[*]` -> canonical `ToolCall` +3. `choices[0].message.reasoning_content` if present -> canonical `reasoning_content` +4. `usage.prompt_tokens/completion_tokens` -> canonical `Usage` + +### Canonical -> Anthropic messages payload + +| Canonical field | Anthropic field | +|---|---| +| system message in `messages` | top-level `system` | +| user/assistant/tool messages | `messages` blocks | +| `tools` | `tools` | +| `max_tokens` | `max_tokens` | +| `temperature` | `temperature` | +| `top_p` | `top_p` | +| `extra_body` | merged into payload | + +Anthropic response parsing: + +1. 
text blocks -> canonical `message.content` +2. tool_use blocks -> canonical `ToolCall` +3. thinking/reasoning blocks when available -> canonical `reasoning_content` +4. usage fields -> canonical `Usage` + +## Client Factory and Provider Resolution + +Add a dedicated factory: + +`packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py` + +Factory inputs: + +1. `ModelConfig` +2. `ModelProvider` +3. `SecretResolver` + +Factory output: + +1. A concrete `ModelClient` implementation + +Routing logic: + +1. `provider_type == "openai"` -> `OpenAICompatibleClient` +2. `provider_type == "anthropic"` -> `AnthropicClient` +3. unknown -> `ValueError` with supported provider types + +Migration safety option: + +1. `provider_type == "litellm-bridge"` or feature flag -> `LiteLLMBridgeClient` +2. lets us verify new abstraction without immediate provider rewrite + +## Error Model and Mapping + +Current `errors.py` pattern-matches LiteLLM exception types. Replace this with canonical provider exceptions. + +### Canonical provider exception + +```python +from dataclasses import dataclass +from enum import Enum + + +class ProviderErrorKind(str, Enum): + API_ERROR = "api_error" + API_CONNECTION = "api_connection" + AUTHENTICATION = "authentication" + CONTEXT_WINDOW_EXCEEDED = "context_window_exceeded" + UNSUPPORTED_PARAMS = "unsupported_params" + BAD_REQUEST = "bad_request" + INTERNAL_SERVER = "internal_server" + NOT_FOUND = "not_found" + PERMISSION_DENIED = "permission_denied" + RATE_LIMIT = "rate_limit" + TIMEOUT = "timeout" + UNPROCESSABLE_ENTITY = "unprocessable_entity" + UNSUPPORTED_CAPABILITY = "unsupported_capability" + + +@dataclass +class ProviderError(Exception): + kind: ProviderErrorKind + message: str + status_code: int | None = None + provider_name: str | None = None + model_name: str | None = None + cause: Exception | None = None +``` + +Then update `handle_llm_exceptions` to match on `ProviderError.kind` and raise existing user-facing `DataDesignerError` subclasses. + +This preserves the public error model while removing LiteLLM-specific coupling. + +## Retry and Backoff + +Replicate current semantics from `LiteLLMRouterDefaultKwargs` and `CustomRouter`: + +1. `initial_retry_after_s = 2.0` +2. `jitter_pct = 0.2` +3. retries at least for `rate_limit` and `timeout` (currently 3) +4. respect provider `Retry-After` when present and reasonable + +Implement this in one shared module used by all adapters. + +## Adaptive Throttling (429-Aware, Sync + Async) + +In addition to retries, adapters should dynamically reduce concurrency when providers return `429`/throttling errors. + +### Design goals + +1. Respect configured `max_parallel_requests` as a hard upper bound. +2. Auto-throttle down on sustained throttling and recover gradually. +3. Use the same throttle state across sync and async code paths. +4. Share throttling across model configs that target the same provider + model identifier. + +### Throttle key and scope + +Use two related keys: + +1. Global cap key: `(provider_name, model_identifier)` +2. Throttle domain key: `(provider_name, model_identifier, throttle_domain)` + +`model_identifier` is the model id from model config (for example, `gpt-5`). + +`throttle_domain` is derived from the actual backend route: + +1. `chat` for chat/completions-backed traffic (including autoregressive image generation via chat) +2. `embedding` for embedding endpoint traffic +3. 
`image` for dedicated image generation endpoint traffic + +This allows multimodal models to share budget when they use the same upstream route and to separate budgets when routes differ. + +### Effective max concurrency across model configs + +When multiple model configs map to the same global cap key: + +1. Compute `effective_max_parallel_requests = min(max_parallel_requests across those model configs)`. +2. Use that effective max as the hard cap for all associated throttle domains. + +This enforces the most conservative configured limit for shared upstream capacity. + +### Shared throttle state + +Store shared state at both levels: + +1. `GlobalCapState` per `(provider_name, model_identifier)`: + - `configured_limits_by_alias` + - `effective_max_limit` (minimum of registered alias limits) +2. `DomainThrottleState` per `(provider_name, model_identifier, throttle_domain)`: + - `current_limit` (`1..effective_max_limit`) + - `in_flight` + - `blocked_until_monotonic` + - `success_streak` + +Each domain state must clamp to the current global `effective_max_limit` if that value decreases. + +State mutation must be thread-safe (`threading.Lock`) and use monotonic time. + +### Execution-model agnostic core API + +Core state methods should be non-blocking so both sync/async wrappers can reuse them: + +1. `try_acquire(now_monotonic) -> float` where return is `wait_seconds` (`0` means acquired) +2. `release_success(now_monotonic) -> None` +3. `release_rate_limited(now_monotonic, retry_after_seconds) -> None` +4. `release_failure(now_monotonic) -> None` + +### Sync and async wrappers + +Use thin wrappers around the same state: + +1. `acquire_sync()` loops with `time.sleep(wait_seconds)` until acquire. +2. `acquire_async()` loops with `await asyncio.sleep(wait_seconds)` until acquire. + +This ensures sync and async traffic co-throttle against the same provider/model throttle domain budget. + +### Adjustment policy + +Use additive-increase / multiplicative-decrease (AIMD): + +1. On `429`/throttling: + - `current_limit = max(1, floor(current_limit * 0.5))` + - set `blocked_until` using `Retry-After` when available, else default backoff + - reset `success_streak` +2. On success: + - increment `success_streak` + - after `success_window` successes, increase `current_limit` by `+1` until `effective_max_limit` +3. On non-429 failures: + - no immediate drop unless configured; still release in-flight slot + +### Provider-specific throttling signals + +1. OpenAI-compatible: HTTP `429`, parse `Retry-After` (seconds/date). +2. Anthropic: HTTP `429`, parse provider headers if present. + +### Integration points + +1. Acquire throttle slot immediately before outbound request attempt. +2. Release slot on completion in `finally`. +3. Apply `release_rate_limited(...)` in retry classifier when error kind is `RATE_LIMIT`. +4. Re-enter acquire path for each retry attempt (so retries also obey adaptive limits). +5. Register each `ModelFacade` limit contribution into `GlobalCapState` during initialization. + +### Config knobs (optional, with safe defaults) + +1. `adaptive_throttle_enabled: bool = true` +2. `adaptive_throttle_min_parallel: int = 1` +3. `adaptive_throttle_reduce_factor: float = 0.5` +4. `adaptive_throttle_success_window: int = 50` +5. `adaptive_throttle_default_block_seconds: float = 2.0` + +### Compatibility expectation with existing config + +1. `max_parallel_requests` remains user-facing and authoritative upper bound. +2. 
If multiple model configs use same provider+model, lower `max_parallel_requests` wins (minimum rule). +3. If adaptive throttling is disabled, behavior reverts to fixed concurrency at the effective max. +4. No public API changes required in `ModelFacade`. + +## `ModelFacade` Refactor Plan + +Minimal diff approach: + +1. Keep public methods and signatures unchanged. +2. Replace `_router` with `_client: ModelClient`. +3. Convert `ChatMessage` list into canonical request. +4. Consume canonical response shape. +5. Preserve all MCP logic exactly as-is. +6. Move usage tracking methods to consume canonical `Usage`. + +Expected code updates: + +1. `facade.py`: swap transport layer calls and response parsing. +2. `factory.py`: initialize client factory instead of applying LiteLLM patches. +3. `errors.py`: map from canonical `ProviderError` instead of LiteLLM exception classes. +4. `lazy_heavy_imports.py`: remove `litellm` entry after complete cutover. + +## Capability Matrix + +Introduce explicit capability checks at adapter and model level. + +| Capability | OpenAI-Compatible | Anthropic | +|---|---|---| +| Chat completion | Yes | Yes | +| Tool calls | Yes | Yes | +| Embeddings | Yes (endpoint/model dependent) | Provider dependent | +| Image generation (diffusion endpoint) | Yes (provider dependent) | Provider dependent | +| Image via chat completion | Some models | Provider dependent | + +If unsupported at runtime: + +1. Return canonical `ProviderError(kind=UNSUPPORTED_CAPABILITY, ...)` +2. Surface as `ModelUnsupportedParamsError` or a dedicated model capability error + +Bedrock capability planning is in Step 2: + +1. `plans/343/model-facade-overhaul-plan-step-2-bedrock.md` + +## Compatibility Matrix + +This matrix defines expected parity between the current LiteLLM-backed implementation and the native adapter implementation. 
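Each row in these tables is meant to be pinned by the backend-parametrized contract tests described under Testing Strategy. A minimal sketch of that parametrization follows; the fixture and `build_stub_facade` helper are hypothetical, while the env var is the backend switch from Operational Guardrails.

```python
import pytest


@pytest.fixture(params=["litellm_bridge", "native"])
def backend_facade(request, monkeypatch):
    # hypothetical helper that builds a ModelFacade against a stub provider
    monkeypatch.setenv("DATA_DESIGNER_MODEL_BACKEND", request.param)
    return build_stub_facade()


def test_embeddings_contract(backend_facade):
    vectors = backend_facade.generate_text_embeddings(input_texts=["a", "b"])
    assert len(vectors) == 2  # same output contract on either backend
```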
+ +### API behavior parity (`ModelFacade`) + +| Surface | Current (LiteLLM) | Target (Native Adapters) | Compatibility expectation | +|---|---|---|---| +| `completion(messages, **kwargs)` | Router chat completion call | `ModelClient.completion(ChatCompletionRequest)` | Same public method signature, same return semantics consumed by `generate` | +| `acompletion(messages, **kwargs)` | Async router chat completion | `ModelClient.acompletion(ChatCompletionRequest)` | Same public method signature and async behavior | +| `generate(prompt, parser, ...)` | Completion + parser correction loop + MCP loop | Same orchestration, adapter-backed completion | Same correction/restart/tool-turn behavior and same trace shape | +| `agenerate(prompt, parser, ...)` | Async completion + async MCP handling | Same orchestration, adapter-backed async completion | Same behavior and error contracts as sync path | +| `generate_text_embeddings(input_texts, **kwargs)` | Router embedding | `ModelClient.embeddings(EmbeddingRequest)` | Same output type (`list[list[float]]`) and length checks | +| `agenerate_text_embeddings(input_texts, **kwargs)` | Async router embedding | `ModelClient.aembeddings(EmbeddingRequest)` | Same output type and error behavior | +| `generate_image(prompt, ...)` | Chat-completion image path or diffusion path | `ModelClient.generate_image(ImageGenerationRequest)` | Same output contract: list of base64 strings | +| `agenerate_image(prompt, ...)` | Async chat/diffusion path | `ModelClient.agenerate_image(ImageGenerationRequest)` | Same output contract and usage update behavior | +| `consolidate_kwargs(**kwargs)` | Merge model inference params + provider extra fields | Same logic before request conversion | No behavioral drift in precedence rules | + +### Error mapping parity + +| Current user-facing error | Current trigger source | Native trigger source | Compatibility expectation | +|---|---|---|---| +| `ModelAPIError` | LiteLLM API-level exceptions | `ProviderError(kind=API_ERROR)` | Same class, equivalent message quality | +| `ModelAPIConnectionError` | LiteLLM connection exceptions | `ProviderError(kind=API_CONNECTION)` | Same class and retryability semantics | +| `ModelAuthenticationError` | LiteLLM auth errors / some 403 API errors | `ProviderError(kind=AUTHENTICATION)` | Same class with provider-specific auth guidance | +| `ModelPermissionDeniedError` | LiteLLM permission denied | `ProviderError(kind=PERMISSION_DENIED)` | Same class and cause semantics | +| `ModelRateLimitError` | LiteLLM rate-limit errors | `ProviderError(kind=RATE_LIMIT)` | Same class and backoff behavior | +| `ModelTimeoutError` | LiteLLM timeout | `ProviderError(kind=TIMEOUT)` | Same class and retry policy | +| `ModelBadRequestError` | LiteLLM bad request | `ProviderError(kind=BAD_REQUEST)` | Same class and actionable remediation | +| `ModelUnsupportedParamsError` | LiteLLM unsupported params | `ProviderError(kind=UNSUPPORTED_PARAMS or UNSUPPORTED_CAPABILITY)` | Same class; message may include capability context | +| `ModelNotFoundError` | LiteLLM not found | `ProviderError(kind=NOT_FOUND)` | Same class | +| `ModelInternalServerError` | LiteLLM internal server | `ProviderError(kind=INTERNAL_SERVER)` | Same class | +| `ModelUnprocessableEntityError` | LiteLLM unprocessable entity | `ProviderError(kind=UNPROCESSABLE_ENTITY)` | Same class | +| `ModelContextWindowExceededError` | LiteLLM context overflow | `ProviderError(kind=CONTEXT_WINDOW_EXCEEDED)` | Same class with context-width hints | +| 
`ModelGenerationValidationFailureError` | Parser correction exhaustion | Same parser correction exhaustion path | Same class and retry/correction semantics | +| `ImageGenerationError` | Image extraction or empty image payload | Same canonical image extraction validations | Same class and failure conditions | + +### Usage and telemetry parity + +| Metric behavior | Current | Target | Compatibility expectation | +|---|---|---|---| +| Request success/failure tracking | Tracked in usage stats | Tracked from canonical response/error outcomes | Same counters and aggregation logic | +| Token usage for chat | From LiteLLM usage fields | From canonical `Usage` parsed by adapter | Same when provider reports usage; graceful fallback when omitted | +| Token usage for embeddings | From LiteLLM usage fields | From canonical `Usage` parsed by adapter | Same behavior | +| Token usage for image generation | From LiteLLM image usage (when present) | From canonical `Usage` parsed by adapter | Same behavior; request counts still update if token usage absent | +| Tool usage tracking | Managed in `generate/agenerate` loops | Unchanged in `ModelFacade` | Exact parity expected | +| Image count usage | Counted from returned images | Counted from canonical image payloads | Exact parity expected | + +### Configuration compatibility + +| Config surface | Current | Target | Compatibility expectation | +|---|---|---|---| +| `ModelProvider.provider_type` | Free-form string | Enumerated known values in factory (later schema hardening) | Existing `"openai"` continues to work unchanged | +| `ModelProvider.api_key` | Top-level optional field | Supported as fallback for auth | Backward compatible during migration window | +| `ModelProvider.auth` | Not present | Optional provider-specific auth object | Additive, non-breaking introduction | +| `extra_headers` | Provider-level dict | Preserved and merged into adapter request | Same precedence behavior | +| `extra_body` | Provider/model kwargs passthrough | Preserved and merged into canonical request payload | Same precedence behavior | +| `inference_parameters.timeout` | Passed through to LiteLLM kwargs | Passed to adapter transport/request timeout | Same intent and default behavior | + +### Non-goals for strict parity + +These are allowed differences and should be documented when encountered: + +1. Provider-specific raw response formats in debug logs. +2. Exact wording of low-level upstream exception strings. +3. Minor token accounting differences when providers omit or redefine usage fields. +4. Unsupported capability messages that are more explicit than current generic errors. + +## Config and CLI Evolution + +Current `ModelProvider.provider_type` is free-form string. Tighten this in phases: + +### Phase A (non-breaking) + +1. Keep `str` but validate known values in factory. +2. Emit warning for unknown values. + +### Phase B (breaking-ready) + +1. Migrate to a constrained literal/enum: + - `"openai"` + - `"anthropic"` +2. Update CLI provider form with controlled options and validation. + +Related files: + +1. `packages/data-designer-config/src/data_designer/config/models.py` +2. `packages/data-designer/src/data_designer/cli/forms/provider_builder.py` + +## Testing Strategy + +### 1. Contract tests for `ModelFacade` + +Goal: prove behavior parity independent of backend. + +1. Keep existing `test_facade.py` behavior assertions. +2. Parametrize backend selection (`litellm_bridge` vs native adapters). +3. Ensure MCP/tool-loop/correction behavior is unchanged. + +### 2. 
Adapter unit tests + +Per adapter: + +1. request mapping tests +2. response parsing tests +3. error mapping tests +4. usage extraction tests +5. retry behavior tests +6. adaptive throttling behavior tests (drop on 429, gradual recovery) + +Tools: + +1. `pytest-httpx` for HTTP adapters + +### 3. Integration smoke tests + +Optional but recommended: + +1. provider-backed smoke tests controlled by env vars +2. run outside CI by default + +### 4. Health check tests + +Ensure `ModelRegistry.run_health_check()` still behaves correctly for: + +1. chat models +2. embedding models +3. image models + +### 5. Sync/async throttle parity tests + +1. Shared throttle state is enforced across mixed sync and async calls for same key. +2. `max_parallel_requests` is never exceeded under concurrent load. +3. `Retry-After` is respected by both sync and async wrappers. +4. Two aliases pointing to same provider/model maintain independent primary limits. +5. Optional shared upstream pressure signal propagates cooldown correctly across aliases when enabled. + +## Migration Phases and Deliverables + +### Phase 0: Baseline and abstraction setup + +Deliverables: + +1. `clients/types.py`, `clients/base.py`, `clients/errors.py` +2. `LiteLLMBridgeClient` adapter +3. no behavior changes expected + +Exit criteria: + +1. current tests pass with bridge client enabled + +### Phase 1: OpenAI-compatible native adapter + +Deliverables: + +1. `OpenAICompatibleClient` sync/async methods +2. shared retry/transport modules +3. shared adaptive throttle manager with sync/async wrappers +4. `ModelFacade` consumes canonical responses + +Exit criteria: + +1. parity on facade contract tests +2. health checks pass with OpenAI-compatible providers +3. adaptive throttling tests pass for sync and async + +### Phase 2: Anthropic adapter + +Deliverables: + +1. `AnthropicClient` chat + tool use +2. unsupported capability handling for unavailable operations + +Exit criteria: + +1. adapter unit tests pass +2. contract tests pass for Anthropic-configured chat workloads + +### Phase 3: LiteLLM deprecation and removal (non-Bedrock paths) + +Deliverables: + +1. remove `litellm_overrides.py` usage path +2. remove LiteLLM dependency from engine package +3. clean up lazy imports and docs + +Exit criteria: + +1. no runtime import path to LiteLLM +2. full test suite green + +## Operational Guardrails + +1. Feature flag backend switch during migration: + - `DATA_DESIGNER_MODEL_BACKEND=litellm_bridge|native` +2. Log provider, model, latency, retry_count per request for observability. +3. Keep raw provider response in debug logs only, with PII-safe handling. +4. Preserve timeout behavior and ensure async cancellation works cleanly. + +## Risks and Mitigations + +### Risk: subtle response-shape mismatch breaks MCP tool loop + +Mitigation: + +1. strict canonical response tests around `tool_calls` shape +2. reuse existing tool-loop tests unchanged + +### Risk: usage reporting regressions + +Mitigation: + +1. make usage optional and track request success/failure even without tokens +2. add regression tests for `ModelUsageStats` updates per operation + +### Risk: provider-specific capability confusion + +Mitigation: + +1. explicit adapter capability methods +2. actionable unsupported-capability errors that name provider/model/operation + +### Risk: retry policy drift + +Mitigation: + +1. central retry module with deterministic tests +2. 
preserve current defaults from `LiteLLMRouterDefaultKwargs` + +### Risk: throttle oscillation or starvation under bursty load + +Mitigation: + +1. bound decreases with minimum limit of `1` +2. additive recovery with tunable success window +3. optional smoothing and per-key metrics (`current_limit`, `429_rate`, queue wait time) + +## Proposed Initial Task Breakdown (Implementation Tickets) + +1. Create `clients/` package with canonical types, base protocols, canonical errors. +2. Implement `LiteLLMBridgeClient` and switch `ModelFacade` to use `ModelClient`. +3. Add provider-specific `auth` schema parsing with compatibility fallback from `api_key`. +4. Refactor `errors.py` to consume canonical provider errors. +5. Implement shared adaptive throttle manager keyed by `(provider_name, model_identifier, throttle_domain)` with sync/async wrappers. +6. Add optional shared upstream pressure signal keyed by `(provider_name, model_identifier)` with domain-aware cooldown propagation. +7. Implement `OpenAICompatibleClient` with sync/async, retry, adaptive throttle, auth headers, and image URL-to-base64 normalization. +8. Add adapter tests and contract parametrization by backend. +9. Implement Anthropic adapter (chat + tools + auth headers). +10. Update CLI provider forms for provider-specific auth input. +11. Cut over default backend to native and run soak period. +12. Remove LiteLLM dependency and legacy overrides for non-Bedrock paths. + +## Definition of Done + +The LiteLLM replacement is complete when all conditions are met: + +1. `ModelFacade` no longer imports or types against LiteLLM. +2. `errors.py` no longer matches on LiteLLM exception classes. +3. Engine package does not depend on LiteLLM. +4. Existing model-facing behavior tests pass against native adapters. +5. OpenAI-compatible and Anthropic adapters are available with documented capability limits. +6. Bedrock work is explicitly deferred to `plans/343/model-facade-overhaul-plan-step-2-bedrock.md`. +7. Documentation and CLI provider guidance are updated. diff --git a/plans/343/model-facade-overhaul-plan-step-2-bedrock.md b/plans/343/model-facade-overhaul-plan-step-2-bedrock.md new file mode 100644 index 000000000..450e222b3 --- /dev/null +++ b/plans/343/model-facade-overhaul-plan-step-2-bedrock.md @@ -0,0 +1,191 @@ +--- +date: 2026-02-19 +authors: + - nmulepati +--- + +# Model Facade Overhaul Plan: Step 2 (Bedrock) + +This step adds native Bedrock support after Step 1 is complete. + +Depends on: + +1. `plans/343/model-facade-overhaul-plan-step-1.md` + +## Scope + +1. Add `BedrockClient` adapter under `engine/models/clients/adapters/bedrock.py`. +2. Add Bedrock auth schema and resolver support. +3. Add Bedrock request/response normalization to canonical model client types. +4. Integrate Bedrock into client factory routing. +5. Add Bedrock-specific tests (unit + integration stubs + optional smoke tests). + +Out of scope: + +1. Re-design of `ModelFacade` public API (already covered by Step 1). +2. Reworking shared retry/throttle abstractions except Bedrock-specific mappings. + +## Adapter Design + +### Factory routing + +1. `provider_type == "bedrock"` -> `BedrockClient` +2. Unknown bedrock auth mode or invalid config -> `ValueError` with actionable message + +### Operations + +1. Chat completion: supported by model-family mapper. +2. Embeddings: supported by model-family mapper. +3. Image generation: supported by model-family mapper. +4. Unsupported operation for chosen model -> canonical `ProviderError(kind=UNSUPPORTED_CAPABILITY)`. 
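As a concrete reference for item 4, a minimal sketch of a fail-fast capability gate, assuming each family mapper exposes a `supported_operations()` method (that mapper API is an assumption; capability declaration per mapper is specified in the next section):

```python
from data_designer.engine.models.clients.errors import ProviderError


def require_capability(mapper, provider_name: str, operation: str) -> None:
    """Raise the canonical unsupported-capability error before any outbound call."""
    # `supported_operations()` is an assumed mapper method; each Bedrock
    # family mapper declares which operations it can serve
    if operation not in mapper.supported_operations():
        raise ProviderError.unsupported_capability(
            provider_name=provider_name,
            operation=operation,
        )
```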
+ +### Model family mappers + +Implement Bedrock mappers per family: + +1. `bedrock_mappers/claude.py` +2. `bedrock_mappers/llama.py` +3. `bedrock_mappers/nova.py` +4. `bedrock_mappers/titan.py` +5. `bedrock_mappers/stability.py` (if configured) + +Each mapper provides: + +1. canonical request -> Bedrock payload conversion +2. Bedrock response -> canonical response conversion +3. capability declaration per operation + +## Authentication + +Bedrock uses SigV4, not API key headers. + +### Supported auth modes + +1. `default_chain`: environment/profile/role chain +2. `profile`: named AWS profile +3. `access_key`: explicit key/secret/session token via `SecretResolver` +4. `assume_role`: STS role assumption + +### Bedrock auth schema + +```python +from __future__ import annotations + +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + + +class BedrockEnvAuth(BaseModel): + mode: Literal["default_chain"] = "default_chain" + region: str + + +class BedrockProfileAuth(BaseModel): + mode: Literal["profile"] = "profile" + region: str + profile_name: str + + +class BedrockKeyAuth(BaseModel): + mode: Literal["access_key"] = "access_key" + region: str + access_key_id: str + secret_access_key: str + session_token: str | None = None + + +class BedrockAssumeRoleAuth(BaseModel): + mode: Literal["assume_role"] = "assume_role" + region: str + role_arn: str + external_id: str | None = None + session_name: str = "data-designer-bedrock" + + +BedrockAuth = Annotated[ + BedrockEnvAuth | BedrockProfileAuth | BedrockKeyAuth | BedrockAssumeRoleAuth, + Field(discriminator="mode"), +] +``` + +### Auth error mapping + +1. `UnrecognizedClientException` -> `ProviderError(kind=AUTHENTICATION)` +2. `AccessDeniedException` -> `ProviderError(kind=PERMISSION_DENIED)` +3. STS assume-role failures -> `ProviderError(kind=AUTHENTICATION | PERMISSION_DENIED)` based on error code + +## Throttling and Concurrency + +Use the same shared throttling framework introduced in Step 1. + +### Keys + +1. Global cap key: `(provider_name, model_identifier)` +2. Domain key: `(provider_name, model_identifier, throttle_domain)` + +### Domain selection + +1. Bedrock chat/converse calls -> `chat` +2. Bedrock embedding calls -> `embedding` +3. Bedrock image calls -> `image` + +### 429/throttling signals + +1. SDK throttling exceptions (`ThrottlingException` family) +2. retry hints when present in SDK metadata + +## Testing + +### Unit tests + +1. mapper payload conversion tests +2. mapper response parsing tests +3. Bedrock auth mode resolution tests +4. Bedrock error-kind mapping tests +5. throttling signal mapping tests + +### Integration-style tests + +1. `botocore.stub.Stubber` for runtime responses +2. mixed sync/async throttling behavior with shared keys + +### Optional smoke tests + +1. gated by env vars/credentials +2. excluded from default CI + +## Delivery Phases + +### Phase A: Bedrock schema and factory wiring + +1. Add Bedrock auth schema types and validation. +2. Add factory route for `provider_type == "bedrock"`. +3. Add placeholder adapter with unsupported-capability errors. + +### Phase B: Chat completion support + +1. Implement chat mapper(s). +2. Implement sync/async chat calls with canonical response conversion. +3. Add chat contract tests. + +### Phase C: Embeddings and image support + +1. Implement embedding and image mappers. +2. Add capability guards for unsupported model families. +3. Add usage and error parity tests. + +### Phase D: Hardening and rollout + +1. 
Add retry/throttle tuning for Bedrock exceptions. +2. Run staged rollout and soak tests. +3. Enable by default after parity gates pass. + +## Definition of Done + +1. `provider_type="bedrock"` resolves to `BedrockClient`. +2. Bedrock auth modes validate and resolve secrets correctly. +3. Chat/embedding/image operations are normalized to canonical response types. +4. Bedrock error mapping surfaces existing user-facing `DataDesignerError` classes. +5. Shared adaptive throttling works with Bedrock throttling signals. +6. Test coverage includes mapper logic, auth, errors, and throttling behavior. From 43824ea00edc8e837efccfdf59f74d6ab42c8dc7 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 20 Feb 2026 09:57:44 -0700 Subject: [PATCH 02/27] update plan --- .../343/model-facade-overhaul-plan-step-1.md | 48 ++++++++++- ...del-facade-overhaul-plan-step-2-bedrock.md | 82 +++++++++++++++++++ 2 files changed, 127 insertions(+), 3 deletions(-) diff --git a/plans/343/model-facade-overhaul-plan-step-1.md b/plans/343/model-facade-overhaul-plan-step-1.md index b3523f048..1a2bbf0de 100644 --- a/plans/343/model-facade-overhaul-plan-step-1.md +++ b/plans/343/model-facade-overhaul-plan-step-1.md @@ -24,6 +24,32 @@ Reviewers should validate three things first: 2. Provider-specific concerns are isolated inside adapters (`openai_compatible`, `anthropic`) behind canonical request/response types. 3. Rollout is reversible with feature flag and bridge adapter until parity is proven. +## Locked Design Decisions (Step 1) + +These are explicit decisions for Step 1 review and implementation. + +1. `ModelFacade` public API remains unchanged. +2. Adapter boundary uses canonical request/response types; provider SDK/HTTP shapes do not leak upward. +3. Adaptive throttling uses: + - global cap key: `(provider_name, model_identifier)` + - domain key: `(provider_name, model_identifier, throttle_domain)` +4. If multiple model configs target the same global cap key, effective hard cap is: + - `min(max_parallel_requests across matching model configs)` +5. Throttle domain is derived from actual backend route: + - chat-completions-backed image generation shares `chat` domain. +6. Auth fallback compatibility is retained: + - top-level `api_key` continues to work for `openai` and `anthropic`. +7. Streaming support is out of scope for Step 1. +8. Bedrock support is intentionally out of scope for Step 1. + +## Reviewer Sign-Off Questions + +These should be answered in review before implementation begins: + +1. Is bridge-first migration (`LiteLLMBridgeClient`) acceptable as mandatory Phase 0? +2. Is the minimum-cap rule for shared provider/model concurrency acceptable? +3. Is the proposed feature-flag rollout (`litellm_bridge|native`) sufficient for rollback needs? + ## Architecture Diagram ### 1. Structural view (boundaries and ownership) @@ -101,7 +127,7 @@ Callers 5) On recovery, additive increase restores capacity up to effective max ``` -## Concrete Implementation Plan (Reviewer-Oriented) +## Concrete Implementation Plan ### File-level change map @@ -1090,8 +1116,24 @@ Ensure `ModelRegistry.run_health_check()` still behaves correctly for: 1. Shared throttle state is enforced across mixed sync and async calls for same key. 2. `max_parallel_requests` is never exceeded under concurrent load. 3. `Retry-After` is respected by both sync and async wrappers. -4. Two aliases pointing to same provider/model maintain independent primary limits. -5. 
Optional shared upstream pressure signal propagates cooldown correctly across aliases when enabled. +4. Two aliases pointing to same provider/model share one global cap whose effective max is the lower configured limit. +5. Domain throttling remains route-aware (`chat`, `embedding`, `image`) under shared global cap. +6. Optional shared upstream pressure signal propagates cooldown correctly across aliases when enabled. + +## Cutover Readiness Gates + +Native backend becomes default only when all gates pass: + +1. Contract tests: + - zero regressions across sync and async `ModelFacade` behavior. +2. Error parity: + - user-facing error classes unchanged for representative failure modes. +3. Throughput stability: + - no sustained degradation in records/sec under standard load profile. +4. Throttling behavior: + - 429 recovery stabilizes without oscillation under stress test profile. +5. Rollback safety: + - feature flag rollback validated in a single release candidate. ## Migration Phases and Deliverables diff --git a/plans/343/model-facade-overhaul-plan-step-2-bedrock.md b/plans/343/model-facade-overhaul-plan-step-2-bedrock.md index 450e222b3..8e7f1d5f2 100644 --- a/plans/343/model-facade-overhaul-plan-step-2-bedrock.md +++ b/plans/343/model-facade-overhaul-plan-step-2-bedrock.md @@ -12,6 +12,47 @@ Depends on: 1. `plans/343/model-facade-overhaul-plan-step-1.md` +## Reviewer Snapshot + +Reviewers should verify: + +1. Step 2 reuses Step 1 abstractions (client boundary, retry, throttling, error model) without forking patterns. +2. Bedrock-specific logic remains isolated to adapter/mappers/auth resolution. +3. Capability gating is explicit and fails early when an operation is unsupported for a model family. + +## Entry Criteria + +Step 2 starts only when these Step 1 conditions are met: + +1. `ModelFacade` is fully backed by `ModelClient` abstraction. +2. Shared retry and adaptive throttle modules are in production. +3. OpenAI-compatible + Anthropic parity gates are complete. +4. Feature-flag-based rollback path is validated. + +## Architecture Delta Diagram + +Step 2 extends the Step 1 client layer by adding Bedrock-specific paths only. + +```text + Step 1 (existing) +ModelFacade -> ModelClient API -> [OpenAI Adapter | Anthropic Adapter | Bridge] + + Step 2 (new) +ModelFacade -> ModelClient API -> [OpenAI | Anthropic | Bedrock | Bridge?] + | + v + Bedrock Mapper Layer + (claude/llama/nova/titan/stability) + | + v + AWS Bedrock Runtime API +``` + +Design intent: + +1. No `ModelFacade` API changes in Step 2. +2. Bedrock complexity is contained under adapter + mapper + auth resolution. + ## Scope 1. Add `BedrockClient` adapter under `engine/models/clients/adapters/bedrock.py`. @@ -25,6 +66,24 @@ Out of scope: 1. Re-design of `ModelFacade` public API (already covered by Step 1). 2. Reworking shared retry/throttle abstractions except Bedrock-specific mappings. +## File-Level Change Map (Step 2) + +New files: + +1. `packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/bedrock.py` +2. `packages/data-designer-engine/src/data_designer/engine/models/clients/bedrock_mappers/claude.py` +3. `packages/data-designer-engine/src/data_designer/engine/models/clients/bedrock_mappers/llama.py` +4. `packages/data-designer-engine/src/data_designer/engine/models/clients/bedrock_mappers/nova.py` +5. `packages/data-designer-engine/src/data_designer/engine/models/clients/bedrock_mappers/titan.py` +6. 
`packages/data-designer-engine/src/data_designer/engine/models/clients/bedrock_mappers/stability.py` (if enabled) + +Updated files: + +1. `packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py` (route `provider_type=bedrock`) +2. `packages/data-designer-config/src/data_designer/config/models.py` (Bedrock auth schema) +3. `packages/data-designer/src/data_designer/cli/forms/provider_builder.py` (Bedrock auth input UX) +4. `packages/data-designer-engine/src/data_designer/engine/models/errors.py` (Bedrock-specific error normalization coverage) + ## Adapter Design ### Factory routing @@ -189,3 +248,26 @@ Use the same shared throttling framework introduced in Step 1. 4. Bedrock error mapping surfaces existing user-facing `DataDesignerError` classes. 5. Shared adaptive throttling works with Bedrock throttling signals. 6. Test coverage includes mapper logic, auth, errors, and throttling behavior. + +## Risks and Mitigations + +### Risk: model-family payload fragmentation + +Mitigation: + +1. strict per-family mapper contracts with canonical input/output tests +2. fail-fast unsupported-capability checks before outbound call + +### Risk: AWS auth misconfiguration complexity + +Mitigation: + +1. explicit auth-mode validation and actionable errors +2. CI stub tests for all auth modes + optional smoke tests in controlled environment + +### Risk: throttling differences vs HTTP providers + +Mitigation: + +1. explicit mapping from Bedrock throttling exceptions to canonical `RATE_LIMIT` +2. stress tests validating AIMD behavior under SDK exception patterns From 2a5f1e4b98c9b005f3e1af0fb6a206bb20d09c57 Mon Sep 17 00:00:00 2001 From: Johnny Greco Date: Fri, 20 Feb 2026 12:05:25 -0500 Subject: [PATCH 03/27] add review --- .../343/_review-model-facade-overhaul-plan.md | 103 ++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 plans/343/_review-model-facade-overhaul-plan.md diff --git a/plans/343/_review-model-facade-overhaul-plan.md b/plans/343/_review-model-facade-overhaul-plan.md new file mode 100644 index 000000000..3f4a22418 --- /dev/null +++ b/plans/343/_review-model-facade-overhaul-plan.md @@ -0,0 +1,103 @@ +--- +date: 2026-02-20 +reviewer: codex +branch: nm/overhaul-model-facade-guts +sources: + - plans/343/model-facade-overhaul-plan-review-codex.md + - plans/343/model-facade-overhaul-plan-review-opus.md +scope: + - plans/343/model-facade-overhaul-plan-step-1.md + - plans/343/model-facade-overhaul-plan-step-2-bedrock.md +--- + +# Aggregated Plan Review (Agreed Feedback) + +## Verdict + +Request changes. The migration direction is solid, but several contradictions and missing contracts should be resolved before implementation. + +## Findings + +### HIGH: `completion`/`acompletion` response contract is not explicit enough for MCP compatibility + +Evidence: +- Step 1 says keep methods/signatures and preserve MCP loop unchanged (`plans/343/model-facade-overhaul-plan-step-1.md:924`, `plans/343/model-facade-overhaul-plan-step-1.md:928`). +- Step 1 also says ModelFacade will consume canonical response shapes (`plans/343/model-facade-overhaul-plan-step-1.md:927`) and maps LiteLLM fields into canonical message fields (`plans/343/model-facade-overhaul-plan-step-1.md:703`). +- Current MCP shape parity risk is explicitly called out (`plans/343/model-facade-overhaul-plan-step-1.md:1160`). + +Recommendation: +- Pick one explicit migration contract: + 1. Refactor ModelFacade + MCP helpers to canonical response shape in PR-2, or + 2. 
Keep LiteLLM-compatible response shape through bridge phase. +- Add dedicated MCP parity tests (tool-call extraction/refusal/reasoning content) for the chosen contract. + +### HIGH: Adaptive throttling contract is internally inconsistent + +Evidence: +- Shared hard cap is defined as `min(max_parallel_requests...)` across aliases (`plans/343/model-facade-overhaul-plan-step-1.md:838`, `plans/343/model-facade-overhaul-plan-step-1.md:839`). +- Test plan also expects aliases on same provider/model to keep independent primary limits (`plans/343/model-facade-overhaul-plan-step-1.md:1093`). + +Recommendation: +- Choose one contract and align tests: + 1. Shared hard cap across aliases (safest), or + 2. Per-alias limits with optional shared pressure signal. +- Remove contradictory assertions from the parity suite. + +### HIGH: Step 1 provider-type hardening conflicts with Step 2 Bedrock routing + +Evidence: +- Step 1 Phase B constrains provider type enum to `openai` and `anthropic` (`plans/343/model-facade-overhaul-plan-step-1.md:1038`). +- Step 2 requires `provider_type == "bedrock"` factory routing (`plans/343/model-facade-overhaul-plan-step-2-bedrock.md:32`). + +Recommendation: +- Keep `provider_type` extensible through Step 1 (or include `bedrock` in planned enum values). +- Defer strict enum narrowing until after Bedrock support lands. + +### HIGH: Rollback safety is inconsistent with default-flip/removal sequencing + +Evidence: +- PR slicing combines cutover default flip and LiteLLM removal in PR-6 (`plans/343/model-facade-overhaul-plan-step-1.md:136`). +- Rollback guardrail promises backend flag toggle (`plans/343/model-facade-overhaul-plan-step-1.md:1152`). +- Phase 3 removes LiteLLM runtime path/dependency (`plans/343/model-facade-overhaul-plan-step-1.md:1141`, `plans/343/model-facade-overhaul-plan-step-1.md:1142`). + +Recommendation: +- Separate default flip from dependency/path removal. +- Keep bridge rollback path for at least one soak/release window after native default. + +### MEDIUM: Auth error mapping is ambiguous for `401/403` + +Evidence: +- Step 1 maps `401/403` to `AUTHENTICATION | PERMISSION_DENIED` with no deterministic rule (`plans/343/model-facade-overhaul-plan-step-1.md:508`, `plans/343/model-facade-overhaul-plan-step-1.md:509`). +- Step 2 has same ambiguity for STS failures (`plans/343/model-facade-overhaul-plan-step-2-bedrock.md:116`). + +Recommendation: +- Define deterministic status/code mapping (default `401 -> AUTHENTICATION`, `403 -> PERMISSION_DENIED`) and document provider-specific exceptions. +- Add explicit parity tests for this matrix. + +### MEDIUM: HTTP client lifecycle and pool sizing are underspecified + +Evidence: +- HTTP adapter skeleton instantiates both sync/async httpx clients and exposes `close`/`aclose` (`plans/343/model-facade-overhaul-plan-step-1.md:583`, `plans/343/model-facade-overhaul-plan-step-1.md:584`, `plans/343/model-facade-overhaul-plan-step-1.md:594`, `plans/343/model-facade-overhaul-plan-step-1.md:597`). +- Plan does not define owner/teardown integration or pool sizing policy. + +Recommendation: +- Add lifecycle section specifying creation/ownership/teardown (factory, registry, or facade shutdown hook). +- Define connection pool sizing relative to concurrency settings. 
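+
+As a concrete illustration of the second recommendation (helper name and constants are ours, not the plan's), the policy could be as small as:
+
+```python
+import httpx
+
+
+def build_pool_limits(effective_max_parallel_requests: int) -> httpx.Limits:
+    # Tie socket-pool bounds to configured concurrency so the transport
+    # never becomes a hidden cap below the throttle's admission limit.
+    return httpx.Limits(
+        max_connections=max(32, 2 * effective_max_parallel_requests),
+        max_keepalive_connections=max(16, effective_max_parallel_requests),
+    )
+```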
+ +### MEDIUM: `extra_body`/`extra_headers` precedence needs explicit contract + +Evidence: +- Plan promises no precedence drift but does not state merge order (`plans/343/model-facade-overhaul-plan-step-1.md:975`, `plans/343/model-facade-overhaul-plan-step-1.md:1014`, `plans/343/model-facade-overhaul-plan-step-1.md:1015`). +- Existing implementation has specific precedence rules (provider overrides for `extra_body`, provider replacement for `extra_headers`). + +Recommendation: +- Document exact merge precedence and add regression tests for it. + +### MEDIUM: Anthropic capability statements conflict within Step 1 + +Evidence: +- Anthropic adapter skeleton marks embeddings and image generation unsupported (`plans/343/model-facade-overhaul-plan-step-1.md:656`, `plans/343/model-facade-overhaul-plan-step-1.md:660`, `plans/343/model-facade-overhaul-plan-step-1.md:674`, `plans/343/model-facade-overhaul-plan-step-1.md:680`). +- Capability matrix says Anthropic support is "Provider dependent" for those operations (`plans/343/model-facade-overhaul-plan-step-1.md:946`, `plans/343/model-facade-overhaul-plan-step-1.md:947`, `plans/343/model-facade-overhaul-plan-step-1.md:948`). + +Recommendation: +- Align matrix with Step 1 implementation scope, or define exact conditions and gates for conditional support. From f945d5b3ccffce1e112f4f18147abec4d4351306 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 20 Feb 2026 11:03:39 -0700 Subject: [PATCH 04/27] address feedback + add more details after several self reviews --- .../343/model-facade-overhaul-plan-step-1.md | 286 +++++++++++++++--- ...del-facade-overhaul-plan-step-2-bedrock.md | 21 +- 2 files changed, 265 insertions(+), 42 deletions(-) diff --git a/plans/343/model-facade-overhaul-plan-step-1.md b/plans/343/model-facade-overhaul-plan-step-1.md index 1a2bbf0de..402dcce73 100644 --- a/plans/343/model-facade-overhaul-plan-step-1.md +++ b/plans/343/model-facade-overhaul-plan-step-1.md @@ -6,6 +6,8 @@ authors: # Model Facade Overhaul Plan: Step 1 (Non-Bedrock) +Review Reference: `plans/343/_review-model-facade-overhaul-plan.md` + This document proposes a concrete migration plan to replace LiteLLM in Data Designer while keeping the public behavior of `ModelFacade` stable. The short version: @@ -24,6 +26,21 @@ Reviewers should validate three things first: 2. Provider-specific concerns are isolated inside adapters (`openai_compatible`, `anthropic`) behind canonical request/response types. 3. Rollout is reversible with feature flag and bridge adapter until parity is proven. +## Feedback Closure Matrix + +This section maps aggregated reviewer findings to the concrete plan updates. 
+ +| Reviewer finding | Resolution in this plan | +|---|---| +| MCP contract for `completion`/`acompletion` is ambiguous | `Explicit MCP Compatibility Contract` defines canonical response contract and required parity tests | +| Adaptive throttling contract had contradictions | `Adaptive Throttling` + `Sync/async throttle parity tests` now align on shared global cap + domain keys | +| Step 1 provider type hardening could block Step 2 Bedrock | `Config and CLI Evolution` keeps provider type extensible and reserves `bedrock` for Step 2 | +| Rollback safety conflicted with flip/removal sequencing | `PR slicing` and `Migration Phases` split default flip/soak from LiteLLM removal | +| Auth mapping for `401/403` ambiguous | `Auth error normalization` now defines deterministic default mapping | +| HTTP client lifecycle/pooling underspecified | `HTTP client lifecycle and pool policy` adds ownership/teardown/pool sizing contract | +| `extra_body`/`extra_headers` precedence unclear | `Request merge / precedence contract` defines explicit merge order | +| Anthropic capability statements inconsistent | `Capability Matrix` now matches Step 1 implementation scope (`No` for embeddings/image) | + ## Locked Design Decisions (Step 1) These are explicit decisions for Step 1 review and implementation. @@ -41,6 +58,8 @@ These are explicit decisions for Step 1 review and implementation. - top-level `api_key` continues to work for `openai` and `anthropic`. 7. Streaming support is out of scope for Step 1. 8. Bedrock support is intentionally out of scope for Step 1. +9. `completion`/`acompletion` contract in Step 1 is canonical response shape. +10. MCP handling is adapted in Step 1 to consume canonical responses; no long-term LiteLLM-shaped response dependency. ## Reviewer Sign-Off Questions @@ -127,6 +146,22 @@ Callers 5) On recovery, additive increase restores capacity up to effective max ``` +## Explicit MCP Compatibility Contract + +Step 1 chooses this migration contract: + +1. `ModelFacade` is refactored to consume canonical `ChatCompletionResponse` in PR-2. +2. MCP helpers (`has_tool_calls`, `tool_call_count`, processing/refusal path) are updated for canonical tool-call fields in the same PR. +3. Bridge adapter translates LiteLLM responses into canonical shape before `ModelFacade` sees them. +4. This contract supersedes any phrasing that MCP helpers stay interface-identical; behavior parity is preserved, interfaces are updated. + +Required parity tests: + +1. tool-call extraction count parity +2. refusal flow parity when tool budget is exceeded +3. reasoning content propagation parity +4. trace message shape parity for assistant/tool messages + ## Concrete Implementation Plan ### File-level change map @@ -148,18 +183,39 @@ Updated files (Step 1): 1. `packages/data-designer-engine/src/data_designer/engine/models/facade.py` 2. `packages/data-designer-engine/src/data_designer/engine/models/errors.py` 3. `packages/data-designer-engine/src/data_designer/engine/models/factory.py` -4. `packages/data-designer-config/src/data_designer/config/models.py` (auth schema extension) -5. `packages/data-designer/src/data_designer/cli/forms/provider_builder.py` (provider-specific auth input) -6. `packages/data-designer-config/src/data_designer/lazy_heavy_imports.py` (remove `litellm` after cutover) +4. `packages/data-designer-engine/src/data_designer/engine/models/registry.py` (adapter lifecycle close/aclose ownership) +5. 
`packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py` (shutdown wiring if needed) +6. `packages/data-designer/src/data_designer/interface/data_designer.py` (invoke resource teardown hooks in generation entrypoints) +7. `packages/data-designer-config/src/data_designer/config/models.py` (auth schema extension) +8. `packages/data-designer/src/data_designer/cli/forms/provider_builder.py` (provider-specific auth input) +9. `packages/data-designer-config/src/data_designer/lazy_heavy_imports.py` (remove `litellm` after cutover) ### PR slicing (recommended) 1. PR-1: canonical types/interfaces/errors + bridge adapter + no behavior change. -2. PR-2: `ModelFacade` switched to `ModelClient` + parity tests passing on bridge. -3. PR-3: OpenAI-compatible adapter + retry + throttle + auth integration. + - files: `clients/base.py`, `clients/types.py`, `clients/errors.py`, `clients/adapters/litellm_bridge.py` + - docs: add architecture notes for canonical adapter boundary and bridge purpose. +2. PR-2: `ModelFacade` switched to `ModelClient` + lifecycle wiring + parity tests on bridge. + - files: `models/facade.py`, `models/errors.py`, `models/factory.py`, `clients/factory.py`, `models/registry.py`, `resources/resource_provider.py`, `interface/data_designer.py` + - docs: update internal lifecycle/ownership docs for adapter teardown and resource shutdown behavior. +3. PR-3: OpenAI-compatible adapter + shared retry/throttle + auth integration. + - files: `clients/retry.py`, `clients/throttle.py`, `clients/adapters/openai_compatible.py` + - docs: add provider docs for openai-compatible routing, endpoint expectations, and retry/throttle semantics. 4. PR-4: Anthropic adapter + auth integration + capability gating. -5. PR-5: CLI/config schema updates + docs + migration guards. -6. PR-6: Cutover flag default flip + LiteLLM removal for Step 1 scope. + - files: `clients/adapters/anthropic.py` + - docs: add Anthropic capability/limitations documentation for Step 1 scope. +5. PR-5: Config/CLI auth schema rollout + migration guards + docs. + - files: `config/models.py`, `cli/forms/provider_builder.py` + - docs: publish auth schema migration guide (legacy `api_key` fallback + typed `auth` objects) and CLI examples. +6. PR-6: Cutover flag default flip to native while retaining bridge path. + - docs: update rollout runbook and env-flag guidance (`DATA_DESIGNER_MODEL_BACKEND`) for operators. +7. PR-7: Remove LiteLLM dependency/path after soak window. + - files: `lazy_heavy_imports.py` and removal of legacy LiteLLM runtime path + - docs: remove LiteLLM references and close out migration notes. + +### PR coverage check (Step 1) + +Every file listed in `File-level change map` must map to exactly one PR above. If a PR changes additional files, they should be explicitly scoped as tests/docs only. ### Reviewer checklist per PR @@ -168,7 +224,8 @@ Updated files (Step 1): 3. Are sync and async paths symmetric in behavior? 4. Does adaptive throttling honor global cap and domain key rules? 5. Is any secret material exposed in logs or reprs? -6. Is rollback possible via feature flag in the same PR? +6. Is rollback possible via feature flag with bridge path retained during soak? +7. Are adapter lifecycle teardown hooks wired (`ModelRegistry`/`ResourceProvider`) with no leaked clients in tests? 
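+
+For item 7, a minimal sketch of the kind of leak check expected in tests (the adapter fixture is a placeholder; `is_closed` is the standard `httpx` client property):
+
+```python
+def test_teardown_closes_http_clients(make_adapter) -> None:
+    # make_adapter is a hypothetical fixture that builds any HTTPAdapterBase subclass.
+    adapter = make_adapter()
+    adapter.close()
+    # The sync transport must be closed after teardown; reaching into a private
+    # attribute is acceptable for this white-box lifecycle assertion.
+    assert adapter._client.is_closed
+```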
## Why This Plan @@ -249,6 +306,7 @@ class Usage: input_tokens: int | None = None output_tokens: int | None = None total_tokens: int | None = None + generated_images: int | None = None @dataclass @@ -335,6 +393,8 @@ Notes: 1. `raw` exists for diagnostics/logging only. 2. Canonical image output is always base64 payload. 3. Tool calls are normalized to `id/name/arguments_json`. +4. `Usage` includes non-token fields when providers expose them (for example `generated_images`). +5. For image generation, if provider usage does not include image counts, `ModelFacade` tracks `generated_images` from `len(images)` to preserve current `image_usage.total_images` behavior. ## Adapter Interfaces @@ -531,8 +591,11 @@ At model client creation time: Map provider auth failures into canonical errors: -1. OpenAI-compatible `401/403` -> `ProviderError(kind=AUTHENTICATION | PERMISSION_DENIED)` -2. Anthropic `401/403` -> same mapping +1. Default mapping: + - `401` -> `ProviderError(kind=AUTHENTICATION)` + - `403` -> `ProviderError(kind=PERMISSION_DENIED)` +2. OpenAI-compatible: follow default mapping unless provider-specific payload indicates otherwise. +3. Anthropic: follow default mapping unless provider-specific payload indicates otherwise. Then map canonical provider errors to existing Data Designer user-facing errors: @@ -577,7 +640,7 @@ from typing import Any import httpx -from data_designer.engine.models.clients.errors import ProviderError +from data_designer.engine.models.clients.errors import ProviderError, map_http_error_to_provider_error from data_designer.engine.models.clients.retry import RetryPolicy, run_with_retries from data_designer.engine.models.clients.types import ( ChatCompletionRequest, @@ -590,6 +653,8 @@ from data_designer.engine.models.clients.types import ( class HTTPAdapterBase: + ROUTES: dict[str, str] = {} + def __init__( self, *, @@ -617,6 +682,43 @@ class HTTPAdapterBase: headers.update(extra_headers) return headers + def _resolve_url(self, route_key: str) -> str: + try: + route_path = self.ROUTES[route_key] + except KeyError as exc: + raise ValueError(f"Unknown route key {route_key!r} for provider {self.provider_name!r}") from exc + return f"{self.endpoint}/{route_path.lstrip('/')}" + + def _post_json( + self, + route_key: str, + payload: dict[str, Any], + extra_headers: dict[str, str] | None = None, + ) -> dict[str, Any]: + response = self._client.post( + self._resolve_url(route_key), + json=payload, + headers=self._headers(extra_headers), + ) + if response.status_code >= 400: + raise map_http_error_to_provider_error(response=response, provider_name=self.provider_name) + return response.json() + + async def _apost_json( + self, + route_key: str, + payload: dict[str, Any], + extra_headers: dict[str, str] | None = None, + ) -> dict[str, Any]: + response = await self._aclient.post( + self._resolve_url(route_key), + json=payload, + headers=self._headers(extra_headers), + ) + if response.status_code >= 400: + raise map_http_error_to_provider_error(response=response, provider_name=self.provider_name) + return response.json() + def close(self) -> None: self._client.close() @@ -624,10 +726,35 @@ class HTTPAdapterBase: await self._aclient.aclose() ``` +### HTTP client lifecycle and pool policy + +Lifecycle contract: + +1. Adapters are created by model client factory and owned by the model registry lifetime. +2. Model registry shutdown is responsible for invoking `close`/`aclose` on all adapter instances. +3. 
`ResourceProvider` exposes `close`/`aclose` and delegates to `ModelRegistry` teardown. +4. `DataDesigner` entrypoints (`create`, `preview`, `validate`) invoke resource teardown in `finally` blocks. +5. Tests must verify no leaked open HTTP clients after teardown. + +Pool sizing policy: + +1. Configure `httpx` limits using effective concurrency with concrete defaults: + - `max_connections = max(32, 2 * effective_max_parallel_requests)` + - `max_keepalive_connections = max(16, effective_max_parallel_requests)` +2. Keep sync and async client limits aligned. +3. Revisit limits per provider if transport characteristics require overrides. +4. Pool limits are derived from the shared effective max cap at client creation time; AIMD adjusts request admission, not socket pool size. + ### OpenAI-compatible adapter skeleton ```python class OpenAICompatibleClient(HTTPAdapterBase, ModelClient): + ROUTES = { + "chat": "/chat/completions", + "embedding": "/embeddings", + "image": "/images/generations", + } + def supports_chat_completion(self) -> bool: return True @@ -651,7 +778,7 @@ class OpenAICompatibleClient(HTTPAdapterBase, ModelClient): payload.update(request.extra_body) response_json = run_with_retries( - fn=lambda: self._post_json("/chat/completions", payload, request.extra_headers), + fn=lambda: self._post_json("chat", payload, request.extra_headers), policy=self.retry_policy, ) return parse_openai_chat_response(response_json) @@ -660,13 +787,40 @@ class OpenAICompatibleClient(HTTPAdapterBase, ModelClient): ... def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - ... + payload = {"model": request.model, "input": request.inputs} + if request.extra_body: + payload.update(request.extra_body) + response_json = run_with_retries( + fn=lambda: self._post_json("embedding", payload, request.extra_headers), + policy=self.retry_policy, + ) + return parse_openai_embedding_response(response_json) async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: ... def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - ... + # Autoregressive image models may use chat route; diffusion-style models use image route. + if request.messages: + route_key = "chat" + payload = openai_chat_image_payload_from_canonical(request) + else: + route_key = "image" + payload = { + "model": request.model, + "prompt": request.prompt, + "n": request.n, + } + payload = {k: v for k, v in payload.items() if v is not None} + + if request.extra_body: + payload.update(request.extra_body) + + response_json = run_with_retries( + fn=lambda: self._post_json(route_key, payload, request.extra_headers), + policy=self.retry_policy, + ) + return parse_openai_image_response(response_json) async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: ... 
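+
+
+# Sketch of the normalization helper referenced by completion() above.
+# Field handling assumes the standard OpenAI chat response layout and the
+# canonical types from types.py; this is illustrative, not final.
+def parse_openai_chat_response(response_json: dict) -> ChatCompletionResponse:
+    message = response_json["choices"][0].get("message", {})
+    usage = response_json.get("usage") or {}
+    tool_calls = [
+        ToolCall(
+            id=raw.get("id", ""),
+            name=raw.get("function", {}).get("name", ""),
+            arguments_json=raw.get("function", {}).get("arguments", "{}"),
+        )
+        for raw in message.get("tool_calls") or []
+    ]
+    return ChatCompletionResponse(
+        message=AssistantMessage(
+            content=message.get("content"),
+            reasoning_content=message.get("reasoning_content"),
+            tool_calls=tool_calls,
+        ),
+        usage=Usage(
+            input_tokens=usage.get("prompt_tokens"),
+            output_tokens=usage.get("completion_tokens"),
+            total_tokens=usage.get("total_tokens"),
+        ),
+        raw=response_json,
+    )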
@@ -676,6 +830,10 @@ class OpenAICompatibleClient(HTTPAdapterBase, ModelClient): ```python class AnthropicClient(HTTPAdapterBase, ModelClient): + ROUTES = { + "chat": "/v1/messages", + } + def supports_chat_completion(self) -> bool: return True @@ -688,7 +846,7 @@ class AnthropicClient(HTTPAdapterBase, ModelClient): def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: payload = anthropic_payload_from_canonical(request) response_json = run_with_retries( - fn=lambda: self._post_json("/v1/messages", payload, anthropic_headers(request.extra_headers)), + fn=lambda: self._post_json("chat", payload, anthropic_headers(request.extra_headers)), policy=self.retry_policy, ) return parse_anthropic_chat_response(response_json) @@ -711,6 +869,16 @@ class AnthropicClient(HTTPAdapterBase, ModelClient): ## Request and Response Mapping Details +### Request merge / precedence contract + +To preserve current behavior, merge precedence is explicit: + +1. Start from model inference params (`generate_kwargs`). +2. Overlay per-call kwargs. +3. Merge `extra_body` with provider extra body taking precedence on key conflicts. +4. Set `extra_headers` from provider extra headers (provider-level replacement semantics). +5. Drop non-provider params like `purpose` before outbound request. + ### Canonical -> OpenAI-compatible chat payload | Canonical field | OpenAI payload field | @@ -770,7 +938,9 @@ Routing logic: 1. `provider_type == "openai"` -> `OpenAICompatibleClient` 2. `provider_type == "anthropic"` -> `AnthropicClient` -3. unknown -> `ValueError` with supported provider types +3. `provider_type == "bedrock"` -> fail fast with: + - `ValueError("provider_type='bedrock' is deferred to Step 2; see plans/343/model-facade-overhaul-plan-step-2-bedrock.md")` +4. unknown -> `ValueError` with supported provider types Migration safety option: @@ -951,7 +1121,7 @@ Minimal diff approach: 2. Replace `_router` with `_client: ModelClient`. 3. Convert `ChatMessage` list into canonical request. 4. Consume canonical response shape. -5. Preserve all MCP logic exactly as-is. +5. Preserve MCP generation semantics (tool budget, refusal behavior, corrections, trace behavior) while updating MCP helper interfaces to canonical response fields. 6. Move usage tracking methods to consume canonical `Usage`. Expected code updates: @@ -960,6 +1130,7 @@ Expected code updates: 2. `factory.py`: initialize client factory instead of applying LiteLLM patches. 3. `errors.py`: map from canonical `ProviderError` instead of LiteLLM exception classes. 4. `lazy_heavy_imports.py`: remove `litellm` entry after complete cutover. +5. `registry.py` + `resource_provider.py` + `data_designer.py`: add deterministic client teardown hooks for sync/async lifecycles. ## Capability Matrix @@ -969,9 +1140,9 @@ Introduce explicit capability checks at adapter and model level. 
|---|---|---| | Chat completion | Yes | Yes | | Tool calls | Yes | Yes | -| Embeddings | Yes (endpoint/model dependent) | Provider dependent | -| Image generation (diffusion endpoint) | Yes (provider dependent) | Provider dependent | -| Image via chat completion | Some models | Provider dependent | +| Embeddings | Yes (endpoint/model dependent) | No (Step 1) | +| Image generation (diffusion endpoint) | Yes (provider dependent) | No (Step 1) | +| Image via chat completion | Some models | No (Step 1) | If unsupported at runtime: @@ -1028,13 +1199,13 @@ This matrix defines expected parity between the current LiteLLM-backed implement | Token usage for embeddings | From LiteLLM usage fields | From canonical `Usage` parsed by adapter | Same behavior | | Token usage for image generation | From LiteLLM image usage (when present) | From canonical `Usage` parsed by adapter | Same behavior; request counts still update if token usage absent | | Tool usage tracking | Managed in `generate/agenerate` loops | Unchanged in `ModelFacade` | Exact parity expected | -| Image count usage | Counted from returned images | Counted from canonical image payloads | Exact parity expected | +| Image count usage | Counted from returned images | Counted from `usage.generated_images` when provided, otherwise from canonical image payload count | Exact parity expected | ### Configuration compatibility | Config surface | Current | Target | Compatibility expectation | |---|---|---|---| -| `ModelProvider.provider_type` | Free-form string | Enumerated known values in factory (later schema hardening) | Existing `"openai"` continues to work unchanged | +| `ModelProvider.provider_type` | Free-form string | Extensible string + known-values validation | Existing `"openai"` continues to work unchanged; `bedrock` remains reserved for Step 2 | | `ModelProvider.api_key` | Top-level optional field | Supported as fallback for auth | Backward compatible during migration window | | `ModelProvider.auth` | Not present | Optional provider-specific auth object | Additive, non-breaking introduction | | `extra_headers` | Provider-level dict | Preserved and merged into adapter request | Same precedence behavior | @@ -1059,12 +1230,11 @@ Current `ModelProvider.provider_type` is free-form string. Tighten this in phase 1. Keep `str` but validate known values in factory. 2. Emit warning for unknown values. -### Phase B (breaking-ready) +### Phase B (post-Step2 hardening) -1. Migrate to a constrained literal/enum: - - `"openai"` - - `"anthropic"` -2. Update CLI provider form with controlled options and validation. +1. Keep provider type extensible in Step 1 while validating known values (`openai`, `anthropic`, `bedrock` reserved). +2. Defer strict enum hardening until after Step 2 Bedrock delivery. +3. Update CLI provider form with controlled options and validation. Related files: @@ -1080,6 +1250,11 @@ Goal: prove behavior parity independent of backend. 1. Keep existing `test_facade.py` behavior assertions. 2. Parametrize backend selection (`litellm_bridge` vs native adapters). 3. Ensure MCP/tool-loop/correction behavior is unchanged. +4. Add explicit MCP parity tests for: + - tool-call extraction count + - refusal path + - reasoning content propagation + - trace message shape ### 2. Adapter unit tests @@ -1091,6 +1266,7 @@ Per adapter: 4. usage extraction tests 5. retry behavior tests 6. adaptive throttling behavior tests (drop on 429, gradual recovery) +7. 
auth status mapping tests (`401 -> AUTHENTICATION`, `403 -> PERMISSION_DENIED`) Tools: @@ -1176,7 +1352,33 @@ Exit criteria: 1. adapter unit tests pass 2. contract tests pass for Anthropic-configured chat workloads -### Phase 3: LiteLLM deprecation and removal (non-Bedrock paths) +### Phase 2b: Config and CLI auth schema rollout + +Deliverables: + +1. typed provider auth schema in `config/models.py` with backward-compatible `api_key` fallback +2. provider-specific auth input flow in `cli/forms/provider_builder.py` +3. migration docs and validation guards for legacy configs + +Exit criteria: + +1. config validation passes for legacy and typed auth examples +2. CLI form tests cover openai/anthropic auth input paths + +### Phase 3: Native default flip + soak window + +Deliverables: + +1. flip default backend to native +2. retain bridge path for rollback +3. run soak window and monitor readiness gates for at least one release window + +Exit criteria: + +1. rollback switch validated in release candidate +2. soak window passes with no blocker regressions + +### Phase 4: LiteLLM deprecation and removal (non-Bedrock paths) Deliverables: @@ -1193,9 +1395,10 @@ Exit criteria: 1. Feature flag backend switch during migration: - `DATA_DESIGNER_MODEL_BACKEND=litellm_bridge|native` -2. Log provider, model, latency, retry_count per request for observability. -3. Keep raw provider response in debug logs only, with PII-safe handling. -4. Preserve timeout behavior and ensure async cancellation works cleanly. +2. Keep bridge path available for rollback until soak/release window completes. +3. Log provider, model, latency, retry_count per request for observability. +4. Keep raw provider response in debug logs only, with PII-safe handling. +5. Preserve timeout behavior and ensure async cancellation works cleanly. ## Risks and Mitigations @@ -1241,14 +1444,16 @@ Mitigation: 2. Implement `LiteLLMBridgeClient` and switch `ModelFacade` to use `ModelClient`. 3. Add provider-specific `auth` schema parsing with compatibility fallback from `api_key`. 4. Refactor `errors.py` to consume canonical provider errors. -5. Implement shared adaptive throttle manager keyed by `(provider_name, model_identifier, throttle_domain)` with sync/async wrappers. -6. Add optional shared upstream pressure signal keyed by `(provider_name, model_identifier)` with domain-aware cooldown propagation. -7. Implement `OpenAICompatibleClient` with sync/async, retry, adaptive throttle, auth headers, and image URL-to-base64 normalization. -8. Add adapter tests and contract parametrization by backend. -9. Implement Anthropic adapter (chat + tools + auth headers). -10. Update CLI provider forms for provider-specific auth input. -11. Cut over default backend to native and run soak period. -12. Remove LiteLLM dependency and legacy overrides for non-Bedrock paths. +5. Add `ModelRegistry.close/aclose` and `ResourceProvider.close/aclose`, and wire teardown in `DataDesigner` entrypoints. +6. Implement shared adaptive throttle manager keyed by `(provider_name, model_identifier, throttle_domain)` with sync/async wrappers. +7. Add optional shared upstream pressure signal keyed by `(provider_name, model_identifier)` with domain-aware cooldown propagation. +8. Implement `OpenAICompatibleClient` with sync/async, retry, adaptive throttle, auth headers, and image URL-to-base64 normalization. +9. Add adapter tests and contract parametrization by backend. +10. Implement Anthropic adapter (chat + tools + auth headers). +11. 
Update CLI provider forms for provider-specific auth input. +12. Flip default backend to native while retaining bridge rollback path. +13. Complete soak window against cutover readiness gates. +14. Remove LiteLLM dependency and legacy overrides for non-Bedrock paths. ## Definition of Done @@ -1260,4 +1465,5 @@ The LiteLLM replacement is complete when all conditions are met: 4. Existing model-facing behavior tests pass against native adapters. 5. OpenAI-compatible and Anthropic adapters are available with documented capability limits. 6. Bedrock work is explicitly deferred to `plans/343/model-facade-overhaul-plan-step-2-bedrock.md`. -7. Documentation and CLI provider guidance are updated. +7. Adapter teardown lifecycle is wired and validated (no leaked open HTTP clients after `create`/`preview`/`validate` flows). +8. Documentation and CLI provider guidance are updated. diff --git a/plans/343/model-facade-overhaul-plan-step-2-bedrock.md b/plans/343/model-facade-overhaul-plan-step-2-bedrock.md index 8e7f1d5f2..61d7a8ed0 100644 --- a/plans/343/model-facade-overhaul-plan-step-2-bedrock.md +++ b/plans/343/model-facade-overhaul-plan-step-2-bedrock.md @@ -6,6 +6,8 @@ authors: # Model Facade Overhaul Plan: Step 2 (Bedrock) +Review Reference: `plans/343/_review-model-facade-overhaul-plan.md` + This step adds native Bedrock support after Step 1 is complete. Depends on: @@ -38,7 +40,7 @@ Step 2 extends the Step 1 client layer by adding Bedrock-specific paths only. ModelFacade -> ModelClient API -> [OpenAI Adapter | Anthropic Adapter | Bridge] Step 2 (new) -ModelFacade -> ModelClient API -> [OpenAI | Anthropic | Bedrock | Bridge?] +ModelFacade -> ModelClient API -> [OpenAI | Anthropic | Bedrock | Bridge (optional soak fallback)] | v Bedrock Mapper Layer @@ -98,6 +100,14 @@ Updated files: 3. Image generation: supported by model-family mapper. 4. Unsupported operation for chosen model -> canonical `ProviderError(kind=UNSUPPORTED_CAPABILITY)`. +### Sync/async execution contract + +1. Step 2 keeps existing `ModelFacade` sync/async signatures unchanged. +2. Bedrock adapter primary transport uses the sync AWS SDK client. +3. Async adapter methods wrap SDK calls through `asyncio.to_thread(...)` to avoid blocking the event loop. +4. Shared retry/throttle/auth logic remains identical between sync and async paths. +5. Streaming remains out of scope. + ### Model family mappers Implement Bedrock mappers per family: @@ -172,7 +182,9 @@ BedrockAuth = Annotated[ 1. `UnrecognizedClientException` -> `ProviderError(kind=AUTHENTICATION)` 2. `AccessDeniedException` -> `ProviderError(kind=PERMISSION_DENIED)` -3. STS assume-role failures -> `ProviderError(kind=AUTHENTICATION | PERMISSION_DENIED)` based on error code +3. STS assume-role failures: + - credential/identity failures -> `ProviderError(kind=AUTHENTICATION)` + - policy/authorization failures -> `ProviderError(kind=PERMISSION_DENIED)` ## Throttling and Concurrency @@ -194,6 +206,11 @@ Use the same shared throttling framework introduced in Step 1. 1. SDK throttling exceptions (`ThrottlingException` family) 2. retry hints when present in SDK metadata +### Lifecycle contract + +1. `BedrockClient` participates in Step 1 lifecycle ownership (`ModelRegistry` -> `ResourceProvider` teardown). +2. Any Bedrock runtime/session object that exposes close semantics is closed from adapter `close`/`aclose`. 
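+
+A minimal sketch of the sync/async execution contract above (class shape and method names are illustrative; the `asyncio.to_thread` pattern is the point):
+
+```python
+import asyncio
+
+
+class BedrockClient:
+    def completion(self, request):
+        # Sync path: call the blocking AWS SDK client directly,
+        # inside the shared retry/throttle wrappers.
+        return self._invoke_with_retries(request)
+
+    async def acompletion(self, request):
+        # Async path: run the same sync call in a worker thread so
+        # the event loop is never blocked by SDK I/O.
+        return await asyncio.to_thread(self.completion, request)
+```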
+ ## Testing ### Unit tests From dfa38175c1bcaf3b5c8a771ad636d998052a2bd6 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 25 Feb 2026 15:15:06 -0700 Subject: [PATCH 05/27] update plan doc --- .../343/model-facade-overhaul-plan-step-1.md | 78 ++++++++++++++++++- 1 file changed, 75 insertions(+), 3 deletions(-) diff --git a/plans/343/model-facade-overhaul-plan-step-1.md b/plans/343/model-facade-overhaul-plan-step-1.md index 402dcce73..69172a010 100644 --- a/plans/343/model-facade-overhaul-plan-step-1.md +++ b/plans/343/model-facade-overhaul-plan-step-1.md @@ -456,6 +456,28 @@ Implementation expectations: 5. Normalize tool calls and reasoning fields. 6. Normalize image outputs from either `b64_json`, data URI, or URL download. +### Image routing ownership contract + +1. `ModelFacade` remains responsible for deciding diffusion-vs-chat image generation using current model semantics (`is_image_diffusion_model(model_name)` and multimodal context). +2. `ModelFacade` constructs `ImageGenerationRequest` accordingly: + - chat-based path: set `messages` and chat-oriented payload fields + - diffusion path: omit `messages` and use prompt/image endpoint payload fields +3. Adapter routing is intentionally dumb and request-shape-based: + - `request.messages is not None` -> chat route + - otherwise -> image-generation route + +### OpenAI-compatible image extraction waterfall + +`parse_openai_image_response` must handle all currently observed formats: + +1. `choices[0].message.images` entries as dicts containing nested `image_url`/data-URI data +2. `choices[0].message.images` entries as plain strings (base64 or data URI) +3. provider image objects with `b64_json` +4. provider image objects with `url` (requires outbound fetch + base64 normalization) +5. `choices[0].message.content` containing raw base64 image payloads + +URL-to-base64 fetches happen in adapter helpers, using a dedicated outbound HTTP fetch client (separate from provider API transport), with timeouts and redacted logging. + ### Anthropic adapter File: `clients/adapters/anthropic.py` @@ -524,6 +546,8 @@ Back-compat rule: 1. If `auth` is absent and `api_key` is present: - for `openai` and `anthropic`, treat as `auth.mode=api_key` +2. If both `auth` and `api_key` are absent for OpenAI-compatible providers: + - treat as explicit no-auth mode and do not send `Authorization` header. ### Proposed typed auth schema @@ -535,17 +559,24 @@ from typing import Literal from pydantic import BaseModel -class OpenAIAuth(BaseModel): +class OpenAIApiKeyAuth(BaseModel): mode: Literal["api_key"] = "api_key" api_key: str organization: str | None = None project: str | None = None +class OpenAINoAuth(BaseModel): + mode: Literal["none"] = "none" + + class AnthropicAuth(BaseModel): mode: Literal["api_key"] = "api_key" api_key: str anthropic_version: str = "2023-06-01" + + +OpenAIAuth = OpenAIApiKeyAuth | OpenAINoAuth ``` ### Adapter auth behavior by provider @@ -554,7 +585,7 @@ class AnthropicAuth(BaseModel): Headers: -1. `Authorization: Bearer ` +1. `Authorization: Bearer ` when `auth.mode == "api_key"` 2. Optional `OpenAI-Organization: ` 3. Optional `OpenAI-Project: ` @@ -615,7 +646,8 @@ Then map canonical provider errors to existing Data Designer user-facing errors: 1. Add optional `auth` field to `ModelProvider`. 2. Keep `api_key` as fallback. -3. Implement adapter builders that accept both forms. +3. Add `OpenAINoAuth` support for auth-optional OpenAI-compatible providers. +4. Implement adapter builders that accept all supported forms. 
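+
+A sketch of the Phase A resolution rule (helper name is ours; the models are the ones proposed above):
+
+```python
+def resolve_openai_auth(provider: ModelProvider) -> OpenAIAuth:
+    # Typed auth wins when present; otherwise fall back to the legacy
+    # top-level api_key; otherwise treat the provider as auth-optional.
+    if provider.auth is not None:
+        return provider.auth
+    if provider.api_key is not None:
+        return OpenAIApiKeyAuth(api_key=provider.api_key)
+    return OpenAINoAuth()
+```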
#### Phase B @@ -735,6 +767,7 @@ Lifecycle contract: 3. `ResourceProvider` exposes `close`/`aclose` and delegates to `ModelRegistry` teardown. 4. `DataDesigner` entrypoints (`create`, `preview`, `validate`) invoke resource teardown in `finally` blocks. 5. Tests must verify no leaked open HTTP clients after teardown. +6. MCP session lifecycle is explicitly owned per run: `ResourceProvider.close()` invokes MCP registry/session-pool cleanup rather than relying only on process-level `atexit`. Pool sizing policy: @@ -879,6 +912,10 @@ To preserve current behavior, merge precedence is explicit: 4. Set `extra_headers` from provider extra headers (provider-level replacement semantics). 5. Drop non-provider params like `purpose` before outbound request. +Merge behavior note: + +1. `extra_body` merge is shallow (top-level keys only), matching current behavior. + ### Canonical -> OpenAI-compatible chat payload | Canonical field | OpenAI payload field | @@ -947,6 +984,15 @@ Migration safety option: 1. `provider_type == "litellm-bridge"` or feature flag -> `LiteLLMBridgeClient` 2. lets us verify new abstraction without immediate provider rewrite +### Bridge coexistence with LiteLLM global patches + +During mixed bridge/native rollout: + +1. `apply_litellm_patches()` must run if any configured model resolves to `LiteLLMBridgeClient`. +2. Patch application must be idempotent and safe when called multiple times. +3. `ThreadSafeCache` + LiteLLM patch behavior remains in place until PR-7 removes bridge/LiteLLM path. +4. PR-7 is the cleanup point for removing `litellm_overrides.py` patch side effects. + ## Error Model and Mapping Current `errors.py` pattern-matches LiteLLM exception types. Replace this with canonical provider exceptions. @@ -1098,6 +1144,13 @@ Use additive-increase / multiplicative-decrease (AIMD): 4. Re-enter acquire path for each retry attempt (so retries also obey adaptive limits). 5. Register each `ModelFacade` limit contribution into `GlobalCapState` during initialization. +### Timeout and cancellation semantics + +1. `inference_parameters.timeout` applies to outbound provider call duration per attempt (transport timeout), not queueing time before admission. +2. Throttle queue wait time is tracked separately in telemetry for observability. +3. Async throttle acquire loops must propagate `asyncio.CancelledError` immediately. +4. Retry loops must not swallow cancellation signals. + ### Config knobs (optional, with safe defaults) 1. `adaptive_throttle_enabled: bool = true` @@ -1123,6 +1176,12 @@ Minimal diff approach: 4. Consume canonical response shape. 5. Preserve MCP generation semantics (tool budget, refusal behavior, corrections, trace behavior) while updating MCP helper interfaces to canonical response fields. 6. Move usage tracking methods to consume canonical `Usage`. +7. Keep `consolidate_kwargs` as the merge source of truth and add typed request builders: + - `_build_chat_request(...)` + - `_build_embedding_request(...)` + - `_build_image_request(...)` + This preserves dynamic inference-parameter sampling semantics while moving transport to canonical request types. +8. Preserve async MCP wrapping pattern (`asyncio.to_thread`) for MCP tool schema retrieval and completion processing where MCP services are sync/event-loop-isolated. 
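+
+Sketch of one typed builder from item 7 (attribute names and field wiring are illustrative; `consolidated` is assumed to already reflect the merge precedence contract):
+
+```python
+def _build_chat_request(self, messages, consolidated: dict) -> ChatCompletionRequest:
+    # Pop transport-level fields into canonical slots; non-provider params
+    # left in `consolidated` stay out of the outbound request by design.
+    return ChatCompletionRequest(
+        model=self.model_identifier,
+        messages=messages,
+        tools=consolidated.pop("tools", None),
+        temperature=consolidated.pop("temperature", None),
+        top_p=consolidated.pop("top_p", None),
+        max_tokens=consolidated.pop("max_tokens", None),
+        timeout=consolidated.pop("timeout", None),
+        extra_body=consolidated.pop("extra_body", None),
+        extra_headers=consolidated.pop("extra_headers", None),
+    )
+```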
Expected code updates: @@ -1190,6 +1249,12 @@ This matrix defines expected parity between the current LiteLLM-backed implement | `ModelGenerationValidationFailureError` | Parser correction exhaustion | Same parser correction exhaustion path | Same class and retry/correction semantics | | `ImageGenerationError` | Image extraction or empty image payload | Same canonical image extraction validations | Same class and failure conditions | +Provider error body normalization requirements: + +1. OpenAI-compatible adapters parse structured error payload fields (`error.type`, `error.code`, `error.param`, `error.message`) into canonical `ProviderError.message`. +2. Anthropic adapters parse structured error payload fields (`error.type`, `error.message`) into canonical `ProviderError.message`. +3. Context-window and multimodal-capability hints should remain actionable in user-facing errors when provider payload includes equivalent context. + ### Usage and telemetry parity | Metric behavior | Current | Target | Compatibility expectation | @@ -1220,6 +1285,7 @@ These are allowed differences and should be documented when encountered: 2. Exact wording of low-level upstream exception strings. 3. Minor token accounting differences when providers omit or redefine usage fields. 4. Unsupported capability messages that are more explicit than current generic errors. +5. Cross-run persistence of adaptive throttle state. In Step 1, throttle state is per `ResourceProvider` lifetime and intentionally resets between independent `create`/`preview`/`validate` runs. ## Config and CLI Evolution @@ -1287,6 +1353,12 @@ Ensure `ModelRegistry.run_health_check()` still behaves correctly for: 2. embedding models 3. image models +Health-check throttle contract: + +1. Health checks go through the same adapter stack for realistic validation. +2. Health checks use a dedicated throttle domain (`healthcheck`) to reduce interference with workload traffic. +3. Health-check outcomes should not mutate adaptive AIMD state used by production generation traffic. + ### 5. Sync/async throttle parity tests 1. Shared throttle state is enforced across mixed sync and async calls for same key. From 0f449a7768615295063a7ca9b3e83fb9740959ff Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 25 Feb 2026 16:58:37 -0700 Subject: [PATCH 06/27] address nits --- .../343/_review-model-facade-overhaul-plan.md | 32 ++++++++--------- .../343/model-facade-overhaul-plan-step-1.md | 34 +++++++++++++------ 2 files changed, 40 insertions(+), 26 deletions(-) diff --git a/plans/343/_review-model-facade-overhaul-plan.md b/plans/343/_review-model-facade-overhaul-plan.md index 3f4a22418..aeaa32b4b 100644 --- a/plans/343/_review-model-facade-overhaul-plan.md +++ b/plans/343/_review-model-facade-overhaul-plan.md @@ -21,9 +21,9 @@ Request changes. The migration direction is solid, but several contradictions an ### HIGH: `completion`/`acompletion` response contract is not explicit enough for MCP compatibility Evidence: -- Step 1 says keep methods/signatures and preserve MCP loop unchanged (`plans/343/model-facade-overhaul-plan-step-1.md:924`, `plans/343/model-facade-overhaul-plan-step-1.md:928`). -- Step 1 also says ModelFacade will consume canonical response shapes (`plans/343/model-facade-overhaul-plan-step-1.md:927`) and maps LiteLLM fields into canonical message fields (`plans/343/model-facade-overhaul-plan-step-1.md:703`). -- Current MCP shape parity risk is explicitly called out (`plans/343/model-facade-overhaul-plan-step-1.md:1160`). 
+- Step 1 says keep methods/signatures and preserve MCP loop unchanged (`plans/343/model-facade-overhaul-plan-step-1.md`). +- Step 1 also says ModelFacade will consume canonical response shapes and maps LiteLLM fields into canonical message fields (`plans/343/model-facade-overhaul-plan-step-1.md`). +- Current MCP shape parity risk is explicitly called out (`plans/343/model-facade-overhaul-plan-step-1.md`). Recommendation: - Pick one explicit migration contract: @@ -34,8 +34,8 @@ Recommendation: ### HIGH: Adaptive throttling contract is internally inconsistent Evidence: -- Shared hard cap is defined as `min(max_parallel_requests...)` across aliases (`plans/343/model-facade-overhaul-plan-step-1.md:838`, `plans/343/model-facade-overhaul-plan-step-1.md:839`). -- Test plan also expects aliases on same provider/model to keep independent primary limits (`plans/343/model-facade-overhaul-plan-step-1.md:1093`). +- Shared hard cap is defined as `min(max_parallel_requests...)` across aliases (`plans/343/model-facade-overhaul-plan-step-1.md`). +- Test plan also expects aliases on same provider/model to keep independent primary limits (`plans/343/model-facade-overhaul-plan-step-1.md`). Recommendation: - Choose one contract and align tests: @@ -46,8 +46,8 @@ Recommendation: ### HIGH: Step 1 provider-type hardening conflicts with Step 2 Bedrock routing Evidence: -- Step 1 Phase B constrains provider type enum to `openai` and `anthropic` (`plans/343/model-facade-overhaul-plan-step-1.md:1038`). -- Step 2 requires `provider_type == "bedrock"` factory routing (`plans/343/model-facade-overhaul-plan-step-2-bedrock.md:32`). +- Step 1 Phase B constrains provider type enum to `openai` and `anthropic` (`plans/343/model-facade-overhaul-plan-step-1.md`). +- Step 2 requires `provider_type == "bedrock"` factory routing (`plans/343/model-facade-overhaul-plan-step-2-bedrock.md`). Recommendation: - Keep `provider_type` extensible through Step 1 (or include `bedrock` in planned enum values). @@ -56,9 +56,9 @@ Recommendation: ### HIGH: Rollback safety is inconsistent with default-flip/removal sequencing Evidence: -- PR slicing combines cutover default flip and LiteLLM removal in PR-6 (`plans/343/model-facade-overhaul-plan-step-1.md:136`). -- Rollback guardrail promises backend flag toggle (`plans/343/model-facade-overhaul-plan-step-1.md:1152`). -- Phase 3 removes LiteLLM runtime path/dependency (`plans/343/model-facade-overhaul-plan-step-1.md:1141`, `plans/343/model-facade-overhaul-plan-step-1.md:1142`). +- PR slicing combines cutover default flip and LiteLLM removal in PR-6 (`plans/343/model-facade-overhaul-plan-step-1.md`). +- Rollback guardrail promises backend flag toggle (`plans/343/model-facade-overhaul-plan-step-1.md`). +- Phase 3 removes LiteLLM runtime path/dependency (`plans/343/model-facade-overhaul-plan-step-1.md`). Recommendation: - Separate default flip from dependency/path removal. @@ -67,8 +67,8 @@ Recommendation: ### MEDIUM: Auth error mapping is ambiguous for `401/403` Evidence: -- Step 1 maps `401/403` to `AUTHENTICATION | PERMISSION_DENIED` with no deterministic rule (`plans/343/model-facade-overhaul-plan-step-1.md:508`, `plans/343/model-facade-overhaul-plan-step-1.md:509`). -- Step 2 has same ambiguity for STS failures (`plans/343/model-facade-overhaul-plan-step-2-bedrock.md:116`). +- Step 1 maps `401/403` to `AUTHENTICATION | PERMISSION_DENIED` with no deterministic rule (`plans/343/model-facade-overhaul-plan-step-1.md`). 
+- Step 2 has same ambiguity for STS failures (`plans/343/model-facade-overhaul-plan-step-2-bedrock.md`). Recommendation: - Define deterministic status/code mapping (default `401 -> AUTHENTICATION`, `403 -> PERMISSION_DENIED`) and document provider-specific exceptions. @@ -77,7 +77,7 @@ Recommendation: ### MEDIUM: HTTP client lifecycle and pool sizing are underspecified Evidence: -- HTTP adapter skeleton instantiates both sync/async httpx clients and exposes `close`/`aclose` (`plans/343/model-facade-overhaul-plan-step-1.md:583`, `plans/343/model-facade-overhaul-plan-step-1.md:584`, `plans/343/model-facade-overhaul-plan-step-1.md:594`, `plans/343/model-facade-overhaul-plan-step-1.md:597`). +- HTTP adapter skeleton instantiates both sync/async httpx clients and exposes `close`/`aclose` (`plans/343/model-facade-overhaul-plan-step-1.md`). - Plan does not define owner/teardown integration or pool sizing policy. Recommendation: @@ -87,7 +87,7 @@ Recommendation: ### MEDIUM: `extra_body`/`extra_headers` precedence needs explicit contract Evidence: -- Plan promises no precedence drift but does not state merge order (`plans/343/model-facade-overhaul-plan-step-1.md:975`, `plans/343/model-facade-overhaul-plan-step-1.md:1014`, `plans/343/model-facade-overhaul-plan-step-1.md:1015`). +- Plan promises no precedence drift but does not state merge order (`plans/343/model-facade-overhaul-plan-step-1.md`). - Existing implementation has specific precedence rules (provider overrides for `extra_body`, provider replacement for `extra_headers`). Recommendation: @@ -96,8 +96,8 @@ Recommendation: ### MEDIUM: Anthropic capability statements conflict within Step 1 Evidence: -- Anthropic adapter skeleton marks embeddings and image generation unsupported (`plans/343/model-facade-overhaul-plan-step-1.md:656`, `plans/343/model-facade-overhaul-plan-step-1.md:660`, `plans/343/model-facade-overhaul-plan-step-1.md:674`, `plans/343/model-facade-overhaul-plan-step-1.md:680`). -- Capability matrix says Anthropic support is "Provider dependent" for those operations (`plans/343/model-facade-overhaul-plan-step-1.md:946`, `plans/343/model-facade-overhaul-plan-step-1.md:947`, `plans/343/model-facade-overhaul-plan-step-1.md:948`). +- Anthropic adapter skeleton marks embeddings and image generation unsupported (`plans/343/model-facade-overhaul-plan-step-1.md`). +- Capability matrix says Anthropic support is "Provider dependent" for those operations (`plans/343/model-facade-overhaul-plan-step-1.md`). Recommendation: - Align matrix with Step 1 implementation scope, or define exact conditions and gates for conditional support. diff --git a/plans/343/model-facade-overhaul-plan-step-1.md b/plans/343/model-facade-overhaul-plan-step-1.md index 69172a010..b932cca75 100644 --- a/plans/343/model-facade-overhaul-plan-step-1.md +++ b/plans/343/model-facade-overhaul-plan-step-1.md @@ -267,10 +267,11 @@ Suggested files: 2. `types.py` - Canonical request/response objects 3. `errors.py` - Provider-agnostic transport/provider exceptions 4. `retry.py` - Backoff policy and retry decision logic -5. `factory.py` - Adapter selection by `provider_type` -6. `adapters/openai_compatible.py` -7. `adapters/anthropic.py` -8. `adapters/litellm_bridge.py` (temporary bridge for migration safety) +5. `throttle.py` - Adaptive concurrency state and AIMD controller +6. `factory.py` - Adapter selection by `provider_type` +7. `adapters/openai_compatible.py` +8. `adapters/anthropic.py` +9. `adapters/litellm_bridge.py` (temporary bridge for migration safety) ### 2. 
Keep `ModelFacade` as orchestrator @@ -554,9 +555,9 @@ Back-compat rule: ```python from __future__ import annotations -from typing import Literal +from typing import Annotated, Literal -from pydantic import BaseModel +from pydantic import BaseModel, Field class OpenAIApiKeyAuth(BaseModel): @@ -576,7 +577,7 @@ class AnthropicAuth(BaseModel): anthropic_version: str = "2023-06-01" -OpenAIAuth = OpenAIApiKeyAuth | OpenAINoAuth +OpenAIAuth = Annotated[OpenAIApiKeyAuth | OpenAINoAuth, Field(discriminator="mode")] ``` ### Adapter auth behavior by provider @@ -706,10 +707,14 @@ class HTTPAdapterBase: self._client = httpx.Client(timeout=self.timeout_s) self._aclient = httpx.AsyncClient(timeout=self.timeout_s) + def _auth_headers(self) -> dict[str, str]: + if not self.api_key: + return {} + return {"Authorization": f"Bearer {self.api_key}"} + def _headers(self, extra_headers: dict[str, str] | None = None) -> dict[str, str]: headers = dict(self.default_headers) - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" + headers.update(self._auth_headers()) if extra_headers: headers.update(extra_headers) return headers @@ -876,10 +881,19 @@ class AnthropicClient(HTTPAdapterBase, ModelClient): def supports_image_generation(self) -> bool: return False + def _auth_headers(self) -> dict[str, str]: + headers = { + "anthropic-version": "2023-06-01", + "content-type": "application/json", + } + if self.api_key: + headers["x-api-key"] = self.api_key + return headers + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: payload = anthropic_payload_from_canonical(request) response_json = run_with_retries( - fn=lambda: self._post_json("chat", payload, anthropic_headers(request.extra_headers)), + fn=lambda: self._post_json("chat", payload, request.extra_headers), policy=self.retry_policy, ) return parse_anthropic_chat_response(response_json) From 08e57f83f83bab339572977efc3eb5685b499687 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 25 Feb 2026 17:00:56 -0700 Subject: [PATCH 07/27] Add cannonical objects --- .../engine/models/clients/__init__.py | 40 ++ .../models/clients/adapters/__init__.py | 6 + .../models/clients/adapters/litellm_bridge.py | 389 ++++++++++++++++++ .../engine/models/clients/base.py | 43 ++ .../engine/models/clients/errors.py | 141 +++++++ .../engine/models/clients/types.py | 94 +++++ .../models/clients/test_client_errors.py | 101 +++++ .../models/clients/test_litellm_bridge.py | 214 ++++++++++ ...facade-overhaul-pr-1-architecture-notes.md | 47 +++ 9 files changed, 1075 insertions(+) create mode 100644 packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/models/clients/base.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py create mode 100644 packages/data-designer-engine/src/data_designer/engine/models/clients/types.py create mode 100644 packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py create mode 100644 packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py create mode 100644 plans/343/model-facade-overhaul-pr-1-architecture-notes.md diff --git 
a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py new file mode 100644 index 000000000..4abdde859 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.clients.errors import ( + ProviderError, + ProviderErrorKind, + map_http_error_to_provider_error, + map_http_status_to_provider_error_kind, +) +from data_designer.engine.models.clients.types import ( + AssistantMessage, + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, + ImagePayload, + ToolCall, + Usage, +) + +__all__ = [ + "AssistantMessage", + "ChatCompletionRequest", + "ChatCompletionResponse", + "EmbeddingRequest", + "EmbeddingResponse", + "ImageGenerationRequest", + "ImageGenerationResponse", + "ImagePayload", + "ModelClient", + "ProviderError", + "ProviderErrorKind", + "ToolCall", + "Usage", + "map_http_error_to_provider_error", + "map_http_status_to_provider_error_kind", +] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py new file mode 100644 index 000000000..0ecbae287 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient + +__all__ = ["LiteLLMBridgeClient"] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py new file mode 100644 index 000000000..740269237 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -0,0 +1,389 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import logging +from typing import Any + +from data_designer.config.utils.image_helpers import ( + extract_base64_from_data_uri, + is_base64_image, + load_image_url_to_base64, +) +from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.clients.types import ( + AssistantMessage, + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, + ImagePayload, + ToolCall, + Usage, +) + +logger = logging.getLogger(__name__) + + +class LiteLLMBridgeClient(ModelClient): + """Bridge adapter that wraps the existing LiteLLM router behind canonical client types.""" + + def __init__(self, *, provider_name: str, router: Any) -> None: + self.provider_name = provider_name + self._router = router + + def supports_chat_completion(self) -> bool: + return True + + def supports_embeddings(self) -> bool: + return True + + def supports_image_generation(self) -> bool: + return True + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + response = self._router.completion( + model=request.model, + messages=request.messages, + **_chat_request_kwargs(request), + ) + return _parse_chat_completion_response(response) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + response = await self._router.acompletion( + model=request.model, + messages=request.messages, + **_chat_request_kwargs(request), + ) + return _parse_chat_completion_response(response) + + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + response = self._router.embedding( + model=request.model, + input=request.inputs, + **_embedding_request_kwargs(request), + ) + vectors = [_extract_embedding_vector(item) for item in getattr(response, "data", [])] + return EmbeddingResponse(vectors=vectors, usage=_extract_usage(getattr(response, "usage", None)), raw=response) + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + response = await self._router.aembedding( + model=request.model, + input=request.inputs, + **_embedding_request_kwargs(request), + ) + vectors = [_extract_embedding_vector(item) for item in getattr(response, "data", [])] + return EmbeddingResponse(vectors=vectors, usage=_extract_usage(getattr(response, "usage", None)), raw=response) + + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + if request.messages is not None: + response = self._router.completion( + model=request.model, + messages=request.messages, + **_image_chat_kwargs(request), + ) + images = _extract_images_from_chat_response(response) + else: + response = self._router.image_generation( + prompt=request.prompt, + model=request.model, + **_image_request_kwargs(request), + ) + images = _extract_images_from_image_response(response) + + usage = _extract_usage(getattr(response, "usage", None), generated_images=len(images)) + return ImageGenerationResponse(images=images, usage=usage, raw=response) + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + if request.messages is not None: + response = await self._router.acompletion( + model=request.model, + messages=request.messages, + **_image_chat_kwargs(request), + ) + images = _extract_images_from_chat_response(response) + else: + response = await self._router.aimage_generation( + prompt=request.prompt, + model=request.model, + 
**_image_request_kwargs(request), + ) + images = _extract_images_from_image_response(response) + + usage = _extract_usage(getattr(response, "usage", None), generated_images=len(images)) + return ImageGenerationResponse(images=images, usage=usage, raw=response) + + def close(self) -> None: + return None + + async def aclose(self) -> None: + return None + + +def _parse_chat_completion_response(response: Any) -> ChatCompletionResponse: + first_choice = _first_or_none(getattr(response, "choices", None)) + message = _value_from(first_choice, "message") + tool_calls = _extract_tool_calls(_value_from(message, "tool_calls")) + images = _extract_images_from_chat_message(message) + assistant_message = AssistantMessage( + content=_coerce_message_content(_value_from(message, "content")), + reasoning_content=_value_from(message, "reasoning_content"), + tool_calls=tool_calls, + images=images, + ) + usage = _extract_usage(getattr(response, "usage", None), generated_images=len(images) if images else None) + return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) + + +def _chat_request_kwargs(request: ChatCompletionRequest) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + if request.tools is not None: + kwargs["tools"] = request.tools + if request.temperature is not None: + kwargs["temperature"] = request.temperature + if request.top_p is not None: + kwargs["top_p"] = request.top_p + if request.max_tokens is not None: + kwargs["max_tokens"] = request.max_tokens + if request.timeout is not None: + kwargs["timeout"] = request.timeout + if request.extra_body is not None: + kwargs["extra_body"] = request.extra_body + if request.extra_headers is not None: + kwargs["extra_headers"] = request.extra_headers + if request.metadata is not None: + kwargs["metadata"] = request.metadata + return kwargs + + +def _embedding_request_kwargs(request: EmbeddingRequest) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + if request.encoding_format is not None: + kwargs["encoding_format"] = request.encoding_format + if request.dimensions is not None: + kwargs["dimensions"] = request.dimensions + if request.timeout is not None: + kwargs["timeout"] = request.timeout + if request.extra_body is not None: + kwargs["extra_body"] = request.extra_body + if request.extra_headers is not None: + kwargs["extra_headers"] = request.extra_headers + return kwargs + + +def _image_request_kwargs(request: ImageGenerationRequest) -> dict[str, Any]: + kwargs: dict[str, Any] = {} + if request.n is not None: + kwargs["n"] = request.n + if request.timeout is not None: + kwargs["timeout"] = request.timeout + if request.extra_body is not None: + kwargs["extra_body"] = request.extra_body + if request.extra_headers is not None: + kwargs["extra_headers"] = request.extra_headers + return kwargs + + +def _image_chat_kwargs(request: ImageGenerationRequest) -> dict[str, Any]: + kwargs = _image_request_kwargs(request) + if request.extra_body is not None: + kwargs["extra_body"] = request.extra_body + return kwargs + + +def _extract_embedding_vector(item: Any) -> list[float]: + value = _value_from(item, "embedding") + if isinstance(value, list): + return [float(v) for v in value] + return [] + + +def _extract_tool_calls(raw_tool_calls: Any) -> list[ToolCall]: + if not raw_tool_calls: + return [] + + normalized_tool_calls: list[ToolCall] = [] + for raw_tool_call in raw_tool_calls: + tool_call_id = _value_from(raw_tool_call, "id") or "" + function = _value_from(raw_tool_call, "function") + name = _value_from(function, "name") or "" + 
arguments_value = _value_from(function, "arguments") + arguments_json = _serialize_tool_arguments(arguments_value) + normalized_tool_calls.append(ToolCall(id=str(tool_call_id), name=str(name), arguments_json=arguments_json)) + + return normalized_tool_calls + + +def _serialize_tool_arguments(arguments_value: Any) -> str: + if arguments_value is None: + return "{}" + if isinstance(arguments_value, str): + return arguments_value + try: + return json.dumps(arguments_value) + except Exception: + return str(arguments_value) + + +def _extract_images_from_chat_response(response: Any) -> list[ImagePayload]: + first_choice = _first_or_none(getattr(response, "choices", None)) + message = _value_from(first_choice, "message") + return _extract_images_from_chat_message(message) + + +def _extract_images_from_chat_message(message: Any) -> list[ImagePayload]: + images: list[ImagePayload] = [] + + raw_images = _value_from(message, "images") + if isinstance(raw_images, list): + for raw_image in raw_images: + parsed_image = _parse_image_payload(raw_image) + if parsed_image is not None: + images.append(parsed_image) + + if images: + return images + + raw_content = _value_from(message, "content") + if isinstance(raw_content, str): + parsed_image = _parse_image_payload(raw_content) + if parsed_image is not None: + images.append(parsed_image) + + return images + + +def _extract_images_from_image_response(response: Any) -> list[ImagePayload]: + images: list[ImagePayload] = [] + for raw_image in getattr(response, "data", []): + parsed_image = _parse_image_payload(raw_image) + if parsed_image is not None: + images.append(parsed_image) + return images + + +def _parse_image_payload(raw_image: Any) -> ImagePayload | None: + try: + if isinstance(raw_image, str): + return _parse_image_string(raw_image) + + if isinstance(raw_image, dict): + if "b64_json" in raw_image and isinstance(raw_image["b64_json"], str): + return ImagePayload(b64_data=raw_image["b64_json"], mime_type=None) + if "image_url" in raw_image: + return _parse_image_payload(raw_image["image_url"]) + if "url" in raw_image and isinstance(raw_image["url"], str): + return _parse_image_string(raw_image["url"]) + + b64_json = _value_from(raw_image, "b64_json") + if isinstance(b64_json, str): + return ImagePayload(b64_data=b64_json, mime_type=None) + + url = _value_from(raw_image, "url") + if isinstance(url, str): + return _parse_image_string(url) + except Exception: + logger.debug("Unable to parse image payload from bridge response object.", exc_info=True) + + return None + + +def _parse_image_string(raw_value: str) -> ImagePayload | None: + if raw_value.startswith("data:image/"): + return ImagePayload( + b64_data=extract_base64_from_data_uri(raw_value), + mime_type=_extract_mime_type_from_data_uri(raw_value), + ) + + if is_base64_image(raw_value): + return ImagePayload(b64_data=raw_value, mime_type=None) + + if raw_value.startswith(("http://", "https://")): + b64_data = load_image_url_to_base64(raw_value) + return ImagePayload(b64_data=b64_data, mime_type=None) + + return None + + +def _extract_mime_type_from_data_uri(data_uri: str) -> str | None: + if not data_uri.startswith("data:"): + return None + head = data_uri.split(",", maxsplit=1)[0] + if ";" in head: + return head[5:].split(";", maxsplit=1)[0] + return head[5:] or None + + +def _extract_usage(raw_usage: Any, generated_images: int | None = None) -> Usage | None: + if raw_usage is None and generated_images is None: + return None + + input_tokens = _value_from(raw_usage, "prompt_tokens") + output_tokens 
= _value_from(raw_usage, "completion_tokens") + total_tokens = _value_from(raw_usage, "total_tokens") + + if input_tokens is None: + input_tokens = _value_from(raw_usage, "input_tokens") + if output_tokens is None: + output_tokens = _value_from(raw_usage, "output_tokens") + + if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int): + total_tokens = input_tokens + output_tokens + + if generated_images is None: + generated_images = _value_from(raw_usage, "generated_images") + if generated_images is None and raw_usage is not None: + generated_images = _value_from(raw_usage, "images") + + if input_tokens is None and output_tokens is None and total_tokens is None and generated_images is None: + return None + + return Usage( + input_tokens=_to_int_or_none(input_tokens), + output_tokens=_to_int_or_none(output_tokens), + total_tokens=_to_int_or_none(total_tokens), + generated_images=_to_int_or_none(generated_images), + ) + + +def _coerce_message_content(content: Any) -> str | None: + if content is None: + return None + if isinstance(content, str): + return content + if isinstance(content, list): + text_parts: list[str] = [] + for block in content: + if isinstance(block, dict): + text_value = block.get("text") + if isinstance(text_value, str): + text_parts.append(text_value) + if text_parts: + return "\n".join(text_parts) + return str(content) + + +def _to_int_or_none(value: Any) -> int | None: + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + return None + + +def _value_from(source: Any, key: str) -> Any: + if source is None: + return None + if isinstance(source, dict): + return source.get(key) + return getattr(source, key, None) + + +def _first_or_none(values: Any) -> Any | None: + if isinstance(values, list) and values: + return values[0] + return None diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py new file mode 100644 index 000000000..516797e01 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Protocol + +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, +) + + +class ChatCompletionClient(Protocol): + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: ... + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: ... + + +class EmbeddingClient(Protocol): + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: ... + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: ... + + +class ImageGenerationClient(Protocol): + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: ... + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: ... + + +class ModelClient(ChatCompletionClient, EmbeddingClient, ImageGenerationClient, Protocol): + provider_name: str + + def supports_chat_completion(self) -> bool: ... + + def supports_embeddings(self) -> bool: ... + + def supports_image_generation(self) -> bool: ... 
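+
+
+# Conformance sketch (illustrative comment only, not shipped code): because
+# ModelClient is a Protocol, adapters satisfy it structurally; no inheritance
+# or registration is required. A hypothetical minimal adapter would look like:
+#
+#     class NullClient:
+#         provider_name = "null"
+#
+#         def supports_chat_completion(self) -> bool:
+#             return False
+#
+#         def supports_embeddings(self) -> bool:
+#             return False
+#
+#         def supports_image_generation(self) -> bool:
+#             return False
+#
+#         # completion/acompletion, embeddings/aembeddings, and the image
+#         # methods then raise or delegate as the adapter requires.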
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py new file mode 100644 index 000000000..b8753c0af --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Any + + +class ProviderErrorKind(str, Enum): + API_ERROR = "api_error" + API_CONNECTION = "api_connection" + AUTHENTICATION = "authentication" + CONTEXT_WINDOW_EXCEEDED = "context_window_exceeded" + UNSUPPORTED_PARAMS = "unsupported_params" + BAD_REQUEST = "bad_request" + INTERNAL_SERVER = "internal_server" + NOT_FOUND = "not_found" + PERMISSION_DENIED = "permission_denied" + RATE_LIMIT = "rate_limit" + TIMEOUT = "timeout" + UNPROCESSABLE_ENTITY = "unprocessable_entity" + UNSUPPORTED_CAPABILITY = "unsupported_capability" + + +@dataclass +class ProviderError(Exception): + kind: ProviderErrorKind + message: str + status_code: int | None = None + provider_name: str | None = None + model_name: str | None = None + cause: Exception | None = None + + def __str__(self) -> str: + return self.message + + @classmethod + def unsupported_capability( + cls, + *, + provider_name: str, + operation: str, + model_name: str | None = None, + message: str | None = None, + ) -> ProviderError: + if message is None: + model_segment = f" for model {model_name!r}" if model_name else "" + message = f"Provider {provider_name!r} does not support operation {operation!r}{model_segment}." + return cls( + kind=ProviderErrorKind.UNSUPPORTED_CAPABILITY, + message=message, + provider_name=provider_name, + model_name=model_name, + ) + + +def map_http_status_to_provider_error_kind(status_code: int, body_text: str = "") -> ProviderErrorKind: + text = body_text.lower() + if status_code == 401: + return ProviderErrorKind.AUTHENTICATION + if status_code == 403: + return ProviderErrorKind.PERMISSION_DENIED + if status_code == 404: + return ProviderErrorKind.NOT_FOUND + if status_code == 408: + return ProviderErrorKind.TIMEOUT + if status_code == 413 or (status_code == 400 and _looks_like_context_window_error(text)): + return ProviderErrorKind.CONTEXT_WINDOW_EXCEEDED + if status_code == 422: + return ProviderErrorKind.UNPROCESSABLE_ENTITY + if status_code == 429: + return ProviderErrorKind.RATE_LIMIT + if status_code == 400: + return ProviderErrorKind.BAD_REQUEST + if 500 <= status_code <= 599: + return ProviderErrorKind.INTERNAL_SERVER + return ProviderErrorKind.API_ERROR + + +def map_http_error_to_provider_error( + *, + response: Any, + provider_name: str, + model_name: str | None = None, +) -> ProviderError: + status_code: int | None = getattr(response, "status_code", None) + body_text = _extract_response_text(response) + + if status_code is None: + return ProviderError( + kind=ProviderErrorKind.API_ERROR, + message=f"Provider {provider_name!r} request failed with an unknown HTTP status.", + provider_name=provider_name, + model_name=model_name, + ) + + kind = map_http_status_to_provider_error_kind(status_code=status_code, body_text=body_text) + return ProviderError( + kind=kind, + message=body_text or f"Provider {provider_name!r} request failed with status code {status_code}.", + status_code=status_code, + provider_name=provider_name, + 
model_name=model_name, + ) + + +def _extract_response_text(response: Any) -> str: + response_text = getattr(response, "text", None) + if isinstance(response_text, str) and response_text.strip(): + return response_text.strip() + + try: + payload = response.json() + except Exception: + return "" + + if isinstance(payload, dict): + for key in ("message", "error", "detail"): + value = payload.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + if isinstance(value, dict): + nested_message = value.get("message") + if isinstance(nested_message, str) and nested_message.strip(): + return nested_message.strip() + return "" + + +def _looks_like_context_window_error(text: str) -> bool: + return any( + token in text + for token in ( + "context window", + "context length", + "maximum context", + "too many tokens", + "max tokens", + ) + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py new file mode 100644 index 000000000..1d269b0a7 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -0,0 +1,94 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class Usage: + input_tokens: int | None = None + output_tokens: int | None = None + total_tokens: int | None = None + generated_images: int | None = None + + +@dataclass +class ImagePayload: + # Canonical output shape to upper layers is base64 without data URI prefix. + b64_data: str + mime_type: str | None = None + + +@dataclass +class ToolCall: + id: str + name: str + arguments_json: str + + +@dataclass +class AssistantMessage: + content: str | None = None + reasoning_content: str | None = None + tool_calls: list[ToolCall] = field(default_factory=list) + images: list[ImagePayload] = field(default_factory=list) + + +@dataclass +class ChatCompletionRequest: + model: str + messages: list[dict[str, Any]] + tools: list[dict[str, Any]] | None = None + temperature: float | None = None + top_p: float | None = None + max_tokens: int | None = None + timeout: float | None = None + extra_body: dict[str, Any] | None = None + extra_headers: dict[str, str] | None = None + metadata: dict[str, Any] | None = None + + +@dataclass +class ChatCompletionResponse: + message: AssistantMessage + usage: Usage | None = None + raw: Any | None = None + + +@dataclass +class EmbeddingRequest: + model: str + inputs: list[str] + encoding_format: str | None = None + dimensions: int | None = None + timeout: float | None = None + extra_body: dict[str, Any] | None = None + extra_headers: dict[str, str] | None = None + + +@dataclass +class EmbeddingResponse: + vectors: list[list[float]] + usage: Usage | None = None + raw: Any | None = None + + +@dataclass +class ImageGenerationRequest: + model: str + prompt: str + messages: list[dict[str, Any]] | None = None + n: int | None = None + timeout: float | None = None + extra_body: dict[str, Any] | None = None + extra_headers: dict[str, str] | None = None + + +@dataclass +class ImageGenerationResponse: + images: list[ImagePayload] + usage: Usage | None = None + raw: Any | None = None diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py 
b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py new file mode 100644 index 000000000..6532a2900 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any + +import pytest + +from data_designer.engine.models.clients.errors import ( + ProviderError, + ProviderErrorKind, + map_http_error_to_provider_error, + map_http_status_to_provider_error_kind, +) + + +class StubHttpResponse: + def __init__(self, *, status_code: int, text: str = "", json_payload: dict[str, Any] | None = None) -> None: + self.status_code = status_code + self.text = text + self._json_payload = json_payload + + def json(self) -> dict[str, Any]: + if self._json_payload is None: + raise ValueError("No JSON payload") + return self._json_payload + + +@pytest.mark.parametrize( + "status_code,body_text,expected_kind", + [ + (401, "", ProviderErrorKind.AUTHENTICATION), + (403, "", ProviderErrorKind.PERMISSION_DENIED), + (404, "", ProviderErrorKind.NOT_FOUND), + (408, "", ProviderErrorKind.TIMEOUT), + (413, "", ProviderErrorKind.CONTEXT_WINDOW_EXCEEDED), + (422, "", ProviderErrorKind.UNPROCESSABLE_ENTITY), + (429, "", ProviderErrorKind.RATE_LIMIT), + (400, "", ProviderErrorKind.BAD_REQUEST), + (400, "maximum context length exceeded", ProviderErrorKind.CONTEXT_WINDOW_EXCEEDED), + (500, "", ProviderErrorKind.INTERNAL_SERVER), + (503, "", ProviderErrorKind.INTERNAL_SERVER), + (418, "", ProviderErrorKind.API_ERROR), + ], +) +def test_map_http_status_to_provider_error_kind( + status_code: int, + body_text: str, + expected_kind: ProviderErrorKind, +) -> None: + assert map_http_status_to_provider_error_kind(status_code=status_code, body_text=body_text) == expected_kind + + +def test_map_http_error_to_provider_error_uses_text_payload() -> None: + response = StubHttpResponse(status_code=429, text="Rate limit hit") + error = map_http_error_to_provider_error(response=response, provider_name="stub-provider", model_name="stub-model") + assert isinstance(error, ProviderError) + assert error.kind == ProviderErrorKind.RATE_LIMIT + assert error.message == "Rate limit hit" + assert error.status_code == 429 + assert error.provider_name == "stub-provider" + assert error.model_name == "stub-model" + + +def test_map_http_error_to_provider_error_uses_json_payload_when_text_missing() -> None: + response = StubHttpResponse( + status_code=403, + text="", + json_payload={"error": "Insufficient permissions for model"}, + ) + error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") + assert error.kind == ProviderErrorKind.PERMISSION_DENIED + assert error.message == "Insufficient permissions for model" + + +def test_map_http_error_to_provider_error_uses_nested_error_message_payload() -> None: + response = StubHttpResponse( + status_code=400, + text="", + json_payload={ + "error": { + "type": "invalid_request_error", + "message": "The request payload is invalid.", + } + }, + ) + error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") + assert error.kind == ProviderErrorKind.BAD_REQUEST + assert error.message == "The request payload is invalid." 
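+
+
+# Consumer sketch (hypothetical; the retry engine lands in a later PR): the
+# canonical kinds are intended to drive retry decisions without re-inspecting
+# raw HTTP status codes, e.g.
+#
+#     RETRYABLE_KINDS = {
+#         ProviderErrorKind.RATE_LIMIT,
+#         ProviderErrorKind.TIMEOUT,
+#         ProviderErrorKind.INTERNAL_SERVER,
+#         ProviderErrorKind.API_CONNECTION,
+#     }
+#
+#     def is_retryable(error: ProviderError) -> bool:
+#         return error.kind in RETRYABLE_KINDS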
+ + +def test_provider_error_unsupported_capability_helper() -> None: + error = ProviderError.unsupported_capability( + provider_name="stub-provider", + operation="image-generation", + model_name="stub-model", + ) + assert error.kind == ProviderErrorKind.UNSUPPORTED_CAPABILITY + assert "image-generation" in error.message + assert "stub-model" in error.message diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py new file mode 100644 index 000000000..23c78ce1a --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py @@ -0,0 +1,214 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + EmbeddingRequest, + ImageGenerationRequest, +) + + +def test_completion_maps_canonical_fields_from_litellm_response() -> None: + response = _build_chat_response( + content="final answer", + reasoning_content="reasoning trace", + tool_calls=[{"id": "call-1", "function": {"name": "lookup", "arguments": '{"query":"foo"}'}}], + usage=SimpleNamespace(prompt_tokens=11, completion_tokens=13, total_tokens=24), + ) + router = MagicMock() + router.completion.return_value = response + client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + + request = ChatCompletionRequest( + model="stub-model", + messages=[{"role": "user", "content": "hello"}], + tools=[{"type": "function", "function": {"name": "lookup"}}], + temperature=0.2, + top_p=0.8, + max_tokens=256, + extra_body={"foo": "bar"}, + extra_headers={"x-trace": "1"}, + metadata={"trace_id": "abc"}, + ) + result = client.completion(request) + + assert result.message.content == "final answer" + assert result.message.reasoning_content == "reasoning trace" + assert len(result.message.tool_calls) == 1 + assert result.message.tool_calls[0].id == "call-1" + assert result.message.tool_calls[0].name == "lookup" + assert result.message.tool_calls[0].arguments_json == '{"query":"foo"}' + assert result.usage is not None + assert result.usage.input_tokens == 11 + assert result.usage.output_tokens == 13 + assert result.usage.total_tokens == 24 + assert result.raw is response + + router.completion.assert_called_once_with( + model="stub-model", + messages=[{"role": "user", "content": "hello"}], + tools=[{"type": "function", "function": {"name": "lookup"}}], + temperature=0.2, + top_p=0.8, + max_tokens=256, + extra_body={"foo": "bar"}, + extra_headers={"x-trace": "1"}, + metadata={"trace_id": "abc"}, + ) + + +@pytest.mark.asyncio +async def test_acompletion_maps_canonical_fields_from_litellm_response() -> None: + response = _build_chat_response(content="async result", reasoning_content=None, tool_calls=[], usage=None) + router = MagicMock() + router.acompletion = AsyncMock(return_value=response) + client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + + request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) + result = await client.acompletion(request) + + assert result.message.content == "async result" + assert 
result.usage is None + router.acompletion.assert_awaited_once_with( + model="stub-model", + messages=[{"role": "user", "content": "hello"}], + ) + + +def test_embeddings_maps_vectors_and_usage() -> None: + response = SimpleNamespace( + data=[{"embedding": [1, 2]}, SimpleNamespace(embedding=[3.5, 4.5])], + usage=SimpleNamespace(prompt_tokens=4, total_tokens=4), + ) + router = MagicMock() + router.embedding.return_value = response + client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + + request = EmbeddingRequest(model="stub-model", inputs=["a", "b"], dimensions=32, encoding_format="float") + result = client.embeddings(request) + + assert result.vectors == [[1.0, 2.0], [3.5, 4.5]] + assert result.usage is not None + assert result.usage.input_tokens == 4 + assert result.usage.output_tokens is None + assert result.raw is response + router.embedding.assert_called_once_with( + model="stub-model", + input=["a", "b"], + encoding_format="float", + dimensions=32, + ) + + +def test_generate_image_uses_chat_completion_path_when_messages_are_present() -> None: + response = _build_chat_response( + content=None, + reasoning_content=None, + tool_calls=None, + images=[{"image_url": {"url": "data:image/png;base64,aGVsbG8="}}], + usage=None, + ) + router = MagicMock() + router.completion.return_value = response + client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + + request = ImageGenerationRequest( + model="stub-model", + prompt="unused because messages are supplied", + messages=[{"role": "user", "content": "generate image"}], + n=1, + ) + result = client.generate_image(request) + + assert len(result.images) == 1 + assert result.images[0].b64_data == "aGVsbG8=" + assert result.images[0].mime_type == "image/png" + assert result.usage is not None + assert result.usage.generated_images == 1 + router.completion.assert_called_once_with( + model="stub-model", + messages=[{"role": "user", "content": "generate image"}], + n=1, + ) + router.image_generation.assert_not_called() + + +def test_generate_image_uses_chat_completion_path_when_messages_is_empty_list() -> None: + response = _build_chat_response( + content=None, + reasoning_content=None, + tool_calls=None, + images=[{"image_url": {"url": "data:image/png;base64,aGVsbG8="}}], + usage=None, + ) + router = MagicMock() + router.completion.return_value = response + client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + + request = ImageGenerationRequest( + model="stub-model", + prompt="unused because messages are supplied", + messages=[], + n=1, + ) + result = client.generate_image(request) + + assert len(result.images) == 1 + router.completion.assert_called_once_with( + model="stub-model", + messages=[], + n=1, + ) + router.image_generation.assert_not_called() + + +def test_generate_image_uses_diffusion_path_without_messages() -> None: + response = SimpleNamespace( + data=[ + SimpleNamespace(b64_json="Zmlyc3Q="), + {"url": "data:image/jpeg;base64,c2Vjb25k"}, + ], + usage=SimpleNamespace(input_tokens=9, output_tokens=12), + ) + router = MagicMock() + router.image_generation.return_value = response + client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + + request = ImageGenerationRequest(model="stub-model", prompt="make an image", n=2) + result = client.generate_image(request) + + assert [image.b64_data for image in result.images] == ["Zmlyc3Q=", "c2Vjb25k"] + assert [image.mime_type for image in result.images] == [None, "image/jpeg"] + assert result.usage is not None + 
assert result.usage.input_tokens == 9 + assert result.usage.output_tokens == 12 + assert result.usage.total_tokens == 21 + assert result.usage.generated_images == 2 + router.image_generation.assert_called_once_with(prompt="make an image", model="stub-model", n=2) + + +def _build_chat_response( + *, + content: str | None, + reasoning_content: str | None, + tool_calls: list[dict[str, Any]] | None, + usage: Any, + images: list[dict[str, Any]] | None = None, +) -> Any: + message = SimpleNamespace( + content=content, + reasoning_content=reasoning_content, + tool_calls=tool_calls, + images=images, + ) + choice = SimpleNamespace(message=message) + return SimpleNamespace(choices=[choice], usage=usage) diff --git a/plans/343/model-facade-overhaul-pr-1-architecture-notes.md b/plans/343/model-facade-overhaul-pr-1-architecture-notes.md new file mode 100644 index 000000000..2391dd0a7 --- /dev/null +++ b/plans/343/model-facade-overhaul-pr-1-architecture-notes.md @@ -0,0 +1,47 @@ +--- +date: 2026-02-25 +authors: + - nmulepati +--- + +# Model Facade Overhaul PR-1 Architecture Notes + +This document captures the architecture intent for PR-1 from +`plans/343/model-facade-overhaul-plan-step-1.md`. + +## Canonical Adapter Boundary + +PR-1 introduces an internal `ModelClient` boundary under: + +`packages/data-designer-engine/src/data_designer/engine/models/clients/` + +Boundary contract: + +1. `ModelFacade`-facing requests/responses use canonical dataclasses in `clients/types.py`. +2. Provider SDK and transport-specific response shapes do not leak above the adapter layer. +3. Provider failures normalize to canonical provider errors (`ProviderError`, `ProviderErrorKind`). + +Canonical operation types in PR-1: + +1. Chat completion (`ChatCompletionRequest` / `ChatCompletionResponse`) +2. Embeddings (`EmbeddingRequest` / `EmbeddingResponse`) +3. Image generation (`ImageGenerationRequest` / `ImageGenerationResponse`) + +## LiteLLM Bridge Purpose + +`LiteLLMBridgeClient` is a temporary adapter that preserves migration safety: + +1. It wraps the existing LiteLLM router while emitting canonical response types. +2. It enables parity testing of canonical request/response contracts before native provider adapters are cut over. +3. It remains available as a rollback path during native adapter soak windows. + +PR-1 is intentionally non-invasive: + +1. No `ModelFacade` call-site behavior changes. +2. No provider routing changes. +3. No retry/throttle lifecycle migration yet. + +## Planned Follow-On + +In PR-2, `ModelFacade` will switch from direct router usage to `ModelClient` implementations, +consuming canonical responses from the bridge first, then native adapters. 
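+
+## Canonical Round-Trip Sketch
+
+A minimal illustration of the boundary contract, assuming an already-built
+LiteLLM router; `make_litellm_router()` below is a placeholder for however the
+router is constructed today, not an API introduced by this PR.
+
+```python
+from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient
+from data_designer.engine.models.clients.types import ChatCompletionRequest
+
+router = make_litellm_router()  # placeholder: existing router construction
+client = LiteLLMBridgeClient(provider_name="openai-compatible", router=router)
+
+request = ChatCompletionRequest(
+    model="example-model",
+    messages=[{"role": "user", "content": "hello"}],
+    temperature=0.2,
+)
+response = client.completion(request)
+
+# Upper layers only see canonical types: AssistantMessage content, normalized
+# ToolCall entries, and Usage token counts.
+print(response.message.content, response.usage)
+```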
From 34349c7e1ff642be04bc6fc808d793473ada7b81 Mon Sep 17 00:00:00 2001
From: Nabin Mulepati
Date: Fri, 27 Feb 2026 17:10:11 -0700
Subject: [PATCH 08/27] address self-review feedback

---
 .../models/clients/adapters/litellm_bridge.py |   5 +-
 .../engine/models/clients/base.py             |   4 +
 .../engine/models/clients/errors.py           |  15 +++
 .../models/clients/test_client_errors.py      |  48 +++++++-
 .../models/clients/test_litellm_bridge.py     | 104 +++++++++++++++++-
 5 files changed, 170 insertions(+), 6 deletions(-)

diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py
index 740269237..b21fbe84e 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py
@@ -189,10 +189,7 @@ def _image_request_kwargs(request: ImageGenerationRequest) -> dict[str, Any]:
 
 
 def _image_chat_kwargs(request: ImageGenerationRequest) -> dict[str, Any]:
-    kwargs = _image_request_kwargs(request)
-    if request.extra_body is not None:
-        kwargs["extra_body"] = request.extra_body
-    return kwargs
+    return _image_request_kwargs(request)
 
 
 def _extract_embedding_vector(item: Any) -> list[float]:
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py
index 516797e01..d1b5cd23a 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/base.py
@@ -41,3 +41,7 @@ def supports_chat_completion(self) -> bool: ...
     def supports_embeddings(self) -> bool: ...
 
     def supports_image_generation(self) -> bool: ...
+
+    def close(self) -> None: ...
+
+    async def aclose(self) -> None: ...
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py
index b8753c0af..02ad4117d 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py
@@ -33,6 +33,11 @@ class ProviderError(Exception):
     model_name: str | None = None
     cause: Exception | None = None
 
+    def __post_init__(self) -> None:
+        Exception.__init__(self, self.message)
+        if self.cause is not None:
+            self.__cause__ = self.cause
+
     def __str__(self) -> str:
         return self.message
 
@@ -107,10 +112,20 @@ def map_http_error_to_provider_error(
 
 
 def _extract_response_text(response: Any) -> str:
+    # Try structured JSON extraction first — most providers return structured error
+    # bodies and we want the human-readable message, not raw JSON.
+ structured = _extract_structured_message(response) + if structured: + return structured + response_text = getattr(response, "text", None) if isinstance(response_text, str) and response_text.strip(): return response_text.strip() + return "" + + +def _extract_structured_message(response: Any) -> str: try: payload = response.json() except Exception: diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py index 6532a2900..7ca75f9d6 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py @@ -52,7 +52,7 @@ def test_map_http_status_to_provider_error_kind( assert map_http_status_to_provider_error_kind(status_code=status_code, body_text=body_text) == expected_kind -def test_map_http_error_to_provider_error_uses_text_payload() -> None: +def test_map_http_error_to_provider_error_uses_text_when_no_json() -> None: response = StubHttpResponse(status_code=429, text="Rate limit hit") error = map_http_error_to_provider_error(response=response, provider_name="stub-provider", model_name="stub-model") assert isinstance(error, ProviderError) @@ -63,6 +63,22 @@ def test_map_http_error_to_provider_error_uses_text_payload() -> None: assert error.model_name == "stub-model" +def test_map_http_error_to_provider_error_prefers_json_over_raw_text() -> None: + response = StubHttpResponse( + status_code=400, + text='{"error": {"type": "invalid_request_error", "message": "Context too long."}}', + json_payload={ + "error": { + "type": "invalid_request_error", + "message": "Context too long.", + } + }, + ) + error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") + assert error.kind == ProviderErrorKind.BAD_REQUEST + assert error.message == "Context too long." 
+
+
 def test_map_http_error_to_provider_error_uses_json_payload_when_text_missing() -> None:
     response = StubHttpResponse(
         status_code=403,
@@ -99,3 +115,33 @@ def test_provider_error_unsupported_capability_helper() -> None:
     assert error.kind == ProviderErrorKind.UNSUPPORTED_CAPABILITY
     assert "image-generation" in error.message
     assert "stub-model" in error.message
+
+
+def test_provider_error_chains_cause_exception() -> None:
+    original = RuntimeError("connection reset")
+    error = ProviderError(
+        kind=ProviderErrorKind.API_CONNECTION,
+        message="Connection failed",
+        cause=original,
+    )
+    assert error.__cause__ is original
+    assert str(error) == "Connection failed"
+
+
+def test_provider_error_without_cause_has_no_chain() -> None:
+    error = ProviderError(
+        kind=ProviderErrorKind.RATE_LIMIT,
+        message="Too many requests",
+    )
+    assert error.__cause__ is None
+
+
+def test_provider_error_is_catchable_as_exception() -> None:
+    error = ProviderError(
+        kind=ProviderErrorKind.AUTHENTICATION,
+        message="Invalid API key",
+    )
+    with pytest.raises(ProviderError) as exc_info:
+        raise error
+    assert exc_info.value.kind == ProviderErrorKind.AUTHENTICATION
+    assert str(exc_info.value) == "Invalid API key"
diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py
index 23c78ce1a..24df83b56 100644
--- a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py
+++ b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py
@@ -196,9 +196,111 @@ def test_generate_image_uses_diffusion_path_without_messages() -> None:
     router.image_generation.assert_called_once_with(prompt="make an image", model="stub-model", n=2)
 
 
+@pytest.mark.asyncio
+async def test_aembeddings_maps_vectors_and_usage() -> None:
+    response = SimpleNamespace(
+        data=[{"embedding": [0.1, 0.2]}, SimpleNamespace(embedding=[0.3, 0.4])],
+        usage=SimpleNamespace(prompt_tokens=5, total_tokens=5),
+    )
+    router = MagicMock()
+    router.aembedding = AsyncMock(return_value=response)
+    client = LiteLLMBridgeClient(provider_name="stub-provider", router=router)
+
+    request = EmbeddingRequest(model="stub-model", inputs=["x", "y"])
+    result = await client.aembeddings(request)
+
+    assert result.vectors == [[0.1, 0.2], [0.3, 0.4]]
+    assert result.usage is not None
+    assert result.usage.input_tokens == 5
+    assert result.raw is response
+    router.aembedding.assert_awaited_once_with(model="stub-model", input=["x", "y"])
+
+
+def test_completion_coerces_list_content_blocks_to_string() -> None:
+    response = _build_chat_response(
+        content=[{"type": "text", "text": "first"}, {"type": "text", "text": "second"}],
+        reasoning_content=None,
+        tool_calls=[],
+        usage=None,
+    )
+    router = MagicMock()
+    router.completion.return_value = response
+    client = LiteLLMBridgeClient(provider_name="stub-provider", router=router)
+
+    request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}])
+    result = client.completion(request)
+
+    assert result.message.content == "first\nsecond"
+
+
+def test_close_is_callable() -> None:
+    router = MagicMock()
+    client = LiteLLMBridgeClient(provider_name="stub-provider", router=router)
+    client.close()
+
+
+@pytest.mark.asyncio
+async def test_aclose_is_callable() -> None:
+    router = MagicMock()
+    client = LiteLLMBridgeClient(provider_name="stub-provider", router=router)
+    await client.aclose()
+
+
+@pytest.mark.asyncio
+async def 
test_agenerate_image_uses_diffusion_path_without_messages() -> None: + response = SimpleNamespace( + data=[SimpleNamespace(b64_json="YXN5bmM=")], + usage=SimpleNamespace(input_tokens=3, output_tokens=7), + ) + router = MagicMock() + router.aimage_generation = AsyncMock(return_value=response) + client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + + request = ImageGenerationRequest(model="stub-model", prompt="async image", n=1) + result = await client.agenerate_image(request) + + assert len(result.images) == 1 + assert result.images[0].b64_data == "YXN5bmM=" + assert result.usage is not None + assert result.usage.generated_images == 1 + router.aimage_generation.assert_awaited_once_with(prompt="async image", model="stub-model", n=1) + + +def test_completion_with_empty_choices_returns_empty_message() -> None: + response = SimpleNamespace(choices=[], usage=None) + router = MagicMock() + router.completion.return_value = response + client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + + request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) + result = client.completion(request) + + assert result.message.content is None + assert result.message.tool_calls == [] + assert result.message.images == [] + + +def test_completion_with_tool_call_dict_arguments() -> None: + response = _build_chat_response( + content=None, + reasoning_content=None, + tool_calls=[{"id": "call-2", "function": {"name": "search", "arguments": {"q": "test"}}}], + usage=None, + ) + router = MagicMock() + router.completion.return_value = response + client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + + request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) + result = client.completion(request) + + assert len(result.message.tool_calls) == 1 + assert result.message.tool_calls[0].arguments_json == '{"q": "test"}' + + def _build_chat_response( *, - content: str | None, + content: Any, reasoning_content: str | None, tool_calls: list[dict[str, Any]] | None, usage: Any, From 6aae4b6c8b159702018d067eba10d2675e80029a Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 27 Feb 2026 17:17:37 -0700 Subject: [PATCH 09/27] add LiteLLMRouter protocol to strongly type bridge router param Co-Authored-By: Claude Opus 4.6 --- .../models/clients/adapters/__init__.py | 4 ++-- .../models/clients/adapters/litellm_bridge.py | 20 +++++++++++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py index 0ecbae287..1b65e2dde 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
 # SPDX-License-Identifier: Apache-2.0
 
-from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient
+from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient, LiteLLMRouter
 
-__all__ = ["LiteLLMBridgeClient"]
+__all__ = ["LiteLLMBridgeClient", "LiteLLMRouter"]
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py
index b21fbe84e..8d7a88f4a 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py
@@ -5,7 +5,7 @@
 
 import json
 import logging
-from typing import Any
+from typing import Any, Protocol
 
 from data_designer.config.utils.image_helpers import (
     extract_base64_from_data_uri,
@@ -29,10 +29,26 @@
 logger = logging.getLogger(__name__)
 
 
+class LiteLLMRouter(Protocol):
+    """Structural type for the LiteLLM router methods the bridge depends on."""
+
+    def completion(self, *, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> Any: ...
+
+    async def acompletion(self, *, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> Any: ...
+
+    def embedding(self, *, model: str, input: list[str], **kwargs: Any) -> Any: ...
+
+    async def aembedding(self, *, model: str, input: list[str], **kwargs: Any) -> Any: ...
+
+    def image_generation(self, *, prompt: str, model: str, **kwargs: Any) -> Any: ...
+
+    async def aimage_generation(self, *, prompt: str, model: str, **kwargs: Any) -> Any: ...
+
+
 class LiteLLMBridgeClient(ModelClient):
     """Bridge adapter that wraps the existing LiteLLM router behind canonical client types."""
 
-    def __init__(self, *, provider_name: str, router: Any) -> None:
+    def __init__(self, *, provider_name: str, router: LiteLLMRouter) -> None:
         self.provider_name = provider_name
         self._router = router

From 2a53d373de87099e825f95c4d1637d89b75a82e9 Mon Sep 17 00:00:00 2001
From: Nabin Mulepati
Date: Fri, 27 Feb 2026 17:25:51 -0700
Subject: [PATCH 10/27] simplify request kwarg collection in the bridge adapter

---
 .../models/clients/adapters/litellm_bridge.py | 82 ++++++-------
 1 file changed, 23 insertions(+), 59 deletions(-)

diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py
index 8d7a88f4a..fff7b207d 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py
@@ -3,6 +3,7 @@
 
 from __future__ import annotations
 
+import dataclasses
 import json
 import logging
 from typing import Any, Protocol
@@ -48,6 +49,11 @@ async def aimage_generation(self, *, prompt: str, model: str, **kwargs: Any) ->
 class LiteLLMBridgeClient(ModelClient):
     """Bridge adapter that wraps the existing LiteLLM router behind canonical client types."""
 
+    # "messages" and "prompt" are passed explicitly to choose between the
+    # chat-completion and diffusion code paths, so exclude them from the
+    # automatic optional-field forwarding.
+ _IMAGE_EXCLUDE = frozenset({"messages", "prompt"}) + def __init__(self, *, provider_name: str, router: LiteLLMRouter) -> None: self.provider_name = provider_name self._router = router @@ -65,7 +71,7 @@ def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: response = self._router.completion( model=request.model, messages=request.messages, - **_chat_request_kwargs(request), + **_collect_non_none_optional_fields(request), ) return _parse_chat_completion_response(response) @@ -73,7 +79,7 @@ async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionRes response = await self._router.acompletion( model=request.model, messages=request.messages, - **_chat_request_kwargs(request), + **_collect_non_none_optional_fields(request), ) return _parse_chat_completion_response(response) @@ -81,7 +87,7 @@ def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: response = self._router.embedding( model=request.model, input=request.inputs, - **_embedding_request_kwargs(request), + **_collect_non_none_optional_fields(request), ) vectors = [_extract_embedding_vector(item) for item in getattr(response, "data", [])] return EmbeddingResponse(vectors=vectors, usage=_extract_usage(getattr(response, "usage", None)), raw=response) @@ -90,24 +96,25 @@ async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: response = await self._router.aembedding( model=request.model, input=request.inputs, - **_embedding_request_kwargs(request), + **_collect_non_none_optional_fields(request), ) vectors = [_extract_embedding_vector(item) for item in getattr(response, "data", [])] return EmbeddingResponse(vectors=vectors, usage=_extract_usage(getattr(response, "usage", None)), raw=response) def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + image_kwargs = _collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) if request.messages is not None: response = self._router.completion( model=request.model, messages=request.messages, - **_image_chat_kwargs(request), + **image_kwargs, ) images = _extract_images_from_chat_response(response) else: response = self._router.image_generation( prompt=request.prompt, model=request.model, - **_image_request_kwargs(request), + **image_kwargs, ) images = _extract_images_from_image_response(response) @@ -115,18 +122,19 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp return ImageGenerationResponse(images=images, usage=usage, raw=response) async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + image_kwargs = _collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) if request.messages is not None: response = await self._router.acompletion( model=request.model, messages=request.messages, - **_image_chat_kwargs(request), + **image_kwargs, ) images = _extract_images_from_chat_response(response) else: response = await self._router.aimage_generation( prompt=request.prompt, model=request.model, - **_image_request_kwargs(request), + **image_kwargs, ) images = _extract_images_from_image_response(response) @@ -155,57 +163,13 @@ def _parse_chat_completion_response(response: Any) -> ChatCompletionResponse: return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) -def _chat_request_kwargs(request: ChatCompletionRequest) -> dict[str, Any]: - kwargs: dict[str, Any] = {} - if request.tools is not None: - kwargs["tools"] = request.tools - if request.temperature is not None: - 
kwargs["temperature"] = request.temperature - if request.top_p is not None: - kwargs["top_p"] = request.top_p - if request.max_tokens is not None: - kwargs["max_tokens"] = request.max_tokens - if request.timeout is not None: - kwargs["timeout"] = request.timeout - if request.extra_body is not None: - kwargs["extra_body"] = request.extra_body - if request.extra_headers is not None: - kwargs["extra_headers"] = request.extra_headers - if request.metadata is not None: - kwargs["metadata"] = request.metadata - return kwargs - - -def _embedding_request_kwargs(request: EmbeddingRequest) -> dict[str, Any]: - kwargs: dict[str, Any] = {} - if request.encoding_format is not None: - kwargs["encoding_format"] = request.encoding_format - if request.dimensions is not None: - kwargs["dimensions"] = request.dimensions - if request.timeout is not None: - kwargs["timeout"] = request.timeout - if request.extra_body is not None: - kwargs["extra_body"] = request.extra_body - if request.extra_headers is not None: - kwargs["extra_headers"] = request.extra_headers - return kwargs - - -def _image_request_kwargs(request: ImageGenerationRequest) -> dict[str, Any]: - kwargs: dict[str, Any] = {} - if request.n is not None: - kwargs["n"] = request.n - if request.timeout is not None: - kwargs["timeout"] = request.timeout - if request.extra_body is not None: - kwargs["extra_body"] = request.extra_body - if request.extra_headers is not None: - kwargs["extra_headers"] = request.extra_headers - return kwargs - - -def _image_chat_kwargs(request: ImageGenerationRequest) -> dict[str, Any]: - return _image_request_kwargs(request) +def _collect_non_none_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: + """Extract non-None optional fields from a request dataclass, skipping *exclude*.""" + return { + f.name: v + for f in dataclasses.fields(request) + if f.name not in exclude and f.default is None and (v := getattr(request, f.name)) is not None + } def _extract_embedding_vector(item: Any) -> list[float]: From 4e2f3afbd7f6999d2aad09db96163b55fc67c0ec Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 27 Feb 2026 17:33:06 -0700 Subject: [PATCH 11/27] add a protol for http response like object --- .../engine/models/clients/__init__.py | 2 ++ .../engine/models/clients/errors.py | 17 +++++++++++++---- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index 4abdde859..d1fdc4290 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -3,6 +3,7 @@ from data_designer.engine.models.clients.base import ModelClient from data_designer.engine.models.clients.errors import ( + HttpResponse, ProviderError, ProviderErrorKind, map_http_error_to_provider_error, @@ -22,6 +23,7 @@ ) __all__ = [ + "HttpResponse", "AssistantMessage", "ChatCompletionRequest", "ChatCompletionResponse", diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index 02ad4117d..9f14a73f8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -5,7 +5,16 @@ from dataclasses import 
dataclass from enum import Enum -from typing import Any +from typing import Any, Protocol + + +class HttpResponse(Protocol): + """Structural type for HTTP response objects (httpx, requests, etc.).""" + + status_code: int + text: str + + def json(self) -> Any: ... class ProviderErrorKind(str, Enum): @@ -86,7 +95,7 @@ def map_http_status_to_provider_error_kind(status_code: int, body_text: str = "" def map_http_error_to_provider_error( *, - response: Any, + response: HttpResponse, provider_name: str, model_name: str | None = None, ) -> ProviderError: @@ -111,7 +120,7 @@ def map_http_error_to_provider_error( ) -def _extract_response_text(response: Any) -> str: +def _extract_response_text(response: HttpResponse) -> str: # Try structured JSON extraction first — most providers return structured error # bodies and we want the human-readable message, not raw JSON. structured = _extract_structured_message(response) @@ -125,7 +134,7 @@ def _extract_response_text(response: Any) -> str: return "" -def _extract_structured_message(response: Any) -> str: +def _extract_structured_message(response: HttpResponse) -> str: try: payload = response.json() except Exception: From b1c85f2adc5f12409fcf464bbb6f5ccc00234bf5 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 27 Feb 2026 17:35:19 -0700 Subject: [PATCH 12/27] move HttpResponse --- .../data_designer/engine/models/clients/__init__.py | 2 +- .../src/data_designer/engine/models/clients/errors.py | 10 +--------- .../src/data_designer/engine/models/clients/types.py | 11 ++++++++++- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index d1fdc4290..dec52401a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -3,7 +3,6 @@ from data_designer.engine.models.clients.base import ModelClient from data_designer.engine.models.clients.errors import ( - HttpResponse, ProviderError, ProviderErrorKind, map_http_error_to_provider_error, @@ -15,6 +14,7 @@ ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse, + HttpResponse, ImageGenerationRequest, ImageGenerationResponse, ImagePayload, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index 9f14a73f8..267226485 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -5,16 +5,8 @@ from dataclasses import dataclass from enum import Enum -from typing import Any, Protocol - -class HttpResponse(Protocol): - """Structural type for HTTP response objects (httpx, requests, etc.).""" - - status_code: int - text: str - - def json(self) -> Any: ... 
+from data_designer.engine.models.clients.types import HttpResponse class ProviderErrorKind(str, Enum): diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index 1d269b0a7..3df379910 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -4,7 +4,16 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any +from typing import Any, Protocol + + +class HttpResponse(Protocol): + """Structural type for HTTP response objects (httpx, requests, etc.).""" + + status_code: int + text: str + + def json(self) -> Any: ... @dataclass From f6dc769172b28611709daa9c03f19a42e56fb61f Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 27 Feb 2026 17:38:02 -0700 Subject: [PATCH 13/27] update PR-1 architecture notes for lifecycle and router protocol Co-Authored-By: Claude Opus 4.6 --- plans/343/model-facade-overhaul-pr-1-architecture-notes.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plans/343/model-facade-overhaul-pr-1-architecture-notes.md b/plans/343/model-facade-overhaul-pr-1-architecture-notes.md index 2391dd0a7..6667d04dc 100644 --- a/plans/343/model-facade-overhaul-pr-1-architecture-notes.md +++ b/plans/343/model-facade-overhaul-pr-1-architecture-notes.md @@ -20,6 +20,7 @@ Boundary contract: 1. `ModelFacade`-facing requests/responses use canonical dataclasses in `clients/types.py`. 2. Provider SDK and transport-specific response shapes do not leak above the adapter layer. 3. Provider failures normalize to canonical provider errors (`ProviderError`, `ProviderErrorKind`). +4. All adapters implement `close`/`aclose` lifecycle methods (defined on `ModelClient` protocol) for deterministic resource teardown. Canonical operation types in PR-1: @@ -31,7 +32,7 @@ Canonical operation types in PR-1: `LiteLLMBridgeClient` is a temporary adapter that preserves migration safety: -1. It wraps the existing LiteLLM router while emitting canonical response types. +1. It wraps the existing LiteLLM router while emitting canonical response types. The router dependency is typed via the `LiteLLMRouter` protocol — a structural type that defines only the six methods the bridge calls, without importing anything from LiteLLM. 2. It enables parity testing of canonical request/response contracts before native provider adapters are cut over. 3. It remains available as a rollback path during native adapter soak windows. 
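+
+A structural-typing sketch of that seam (the `FakeRouter` double below is
+illustrative, not shipped code): any object exposing these six methods
+satisfies `LiteLLMRouter` without importing LiteLLM.
+
+```python
+class FakeRouter:
+    def completion(self, *, model, messages, **kwargs): ...
+    async def acompletion(self, *, model, messages, **kwargs): ...
+    def embedding(self, *, model, input, **kwargs): ...
+    async def aembedding(self, *, model, input, **kwargs): ...
+    def image_generation(self, *, prompt, model, **kwargs): ...
+    async def aimage_generation(self, *, prompt, model, **kwargs): ...
+```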
From ec5ed9b259e21203a896033245adb2d6cd52e820 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 4 Mar 2026 09:55:10 -0700 Subject: [PATCH 14/27] Address PR #359 feedback: exception wrapping, shared parsing, test improvements - Wrap all LiteLLM router calls in try/except to normalize raw exceptions into canonical ProviderError at the bridge boundary (blocking review item) - Extract reusable response-parsing helpers into clients/parsing.py for shared use across future native adapters - Add async image parsing path using httpx.AsyncClient to avoid blocking the event loop in agenerate_image - Add retry_after field to ProviderError for future retry engine support - Fix _to_int_or_none to parse numeric strings from providers - Create test conftest.py with shared mock_router/bridge_client fixtures - Parametrize duplicate image generation and error mapping tests - Add tests for exception wrapping across all bridge methods --- .../config/utils/image_helpers.py | 19 + .../models/clients/adapters/litellm_bridge.py | 407 ++++++------------ .../engine/models/clients/errors.py | 20 + .../engine/models/clients/parsing.py | 318 ++++++++++++++ .../tests/engine/models/clients/conftest.py | 20 + .../models/clients/test_client_errors.py | 128 +++--- .../models/clients/test_litellm_bridge.py | 307 ++++++++----- 7 files changed, 788 insertions(+), 431 deletions(-) create mode 100644 packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py create mode 100644 packages/data-designer-engine/tests/engine/models/clients/conftest.py diff --git a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py index 1f0e5c919..934be5b43 100644 --- a/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py +++ b/packages/data-designer-config/src/data_designer/config/utils/image_helpers.py @@ -260,6 +260,25 @@ def load_image_url_to_base64(url: str, timeout: int = 60) -> str: return base64.b64encode(resp.content).decode() +async def aload_image_url_to_base64(url: str, timeout: int = 60) -> str: + """Download an image from a URL asynchronously and return as base64. + + Args: + url: HTTP(S) URL pointing to an image. + timeout: Request timeout in seconds. + + Returns: + Base64-encoded image data. + + Raises: + httpx.HTTPStatusError: If the download fails with a non-2xx status. + """ + async with lazy.httpx.AsyncClient() as client: + resp = await client.get(url, timeout=timeout) + resp.raise_for_status() + return base64.b64encode(resp.content).decode() + + def validate_image(image_path: Path) -> None: """Validate that an image file is readable and not corrupted. 
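A quick usage sketch contrasting the new async helper with its existing sync counterpart (the URL is a placeholder; only the two helper functions are real):

```python
import asyncio

from data_designer.config.utils.image_helpers import (
    aload_image_url_to_base64,
    load_image_url_to_base64,
)

IMAGE_URL = "https://example.com/image.png"  # placeholder


def fetch_sync() -> str:
    # Blocking download: fine in a worker thread, wrong inside a running event loop.
    return load_image_url_to_base64(IMAGE_URL, timeout=30)


async def fetch_async() -> str:
    # Non-blocking download: what the bridge's agenerate_image parsing path awaits.
    return await aload_image_url_to_base64(IMAGE_URL, timeout=30)


if __name__ == "__main__":
    print(asyncio.run(fetch_async()) == fetch_sync())  # same base64 payload either way
```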
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py index fff7b207d..09194aa4b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -3,28 +3,28 @@ from __future__ import annotations -import dataclasses -import json import logging from typing import Any, Protocol -from data_designer.config.utils.image_helpers import ( - extract_base64_from_data_uri, - is_base64_image, - load_image_url_to_base64, -) from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind +from data_designer.engine.models.clients.parsing import ( + aextract_images_from_chat_response, + aextract_images_from_image_response, + collect_non_none_optional_fields, + extract_embedding_vector, + extract_images_from_chat_response, + extract_images_from_image_response, + extract_usage, + parse_chat_completion_response, +) from data_designer.engine.models.clients.types import ( - AssistantMessage, ChatCompletionRequest, ChatCompletionResponse, EmbeddingRequest, EmbeddingResponse, ImageGenerationRequest, ImageGenerationResponse, - ImagePayload, - ToolCall, - Usage, ) logger = logging.getLogger(__name__) @@ -68,77 +68,107 @@ def supports_image_generation(self) -> bool: return True def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - response = self._router.completion( - model=request.model, - messages=request.messages, - **_collect_non_none_optional_fields(request), - ) - return _parse_chat_completion_response(response) - - async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - response = await self._router.acompletion( - model=request.model, - messages=request.messages, - **_collect_non_none_optional_fields(request), - ) - return _parse_chat_completion_response(response) - - def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - response = self._router.embedding( - model=request.model, - input=request.inputs, - **_collect_non_none_optional_fields(request), - ) - vectors = [_extract_embedding_vector(item) for item in getattr(response, "data", [])] - return EmbeddingResponse(vectors=vectors, usage=_extract_usage(getattr(response, "usage", None)), raw=response) - - async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - response = await self._router.aembedding( - model=request.model, - input=request.inputs, - **_collect_non_none_optional_fields(request), - ) - vectors = [_extract_embedding_vector(item) for item in getattr(response, "data", [])] - return EmbeddingResponse(vectors=vectors, usage=_extract_usage(getattr(response, "usage", None)), raw=response) - - def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - image_kwargs = _collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) - if request.messages is not None: + try: response = self._router.completion( model=request.model, messages=request.messages, - **image_kwargs, - ) - images = _extract_images_from_chat_response(response) - else: - response = self._router.image_generation( - prompt=request.prompt, - model=request.model, - **image_kwargs, + **collect_non_none_optional_fields(request), ) - images = 
_extract_images_from_image_response(response) - - usage = _extract_usage(getattr(response, "usage", None), generated_images=len(images)) - return ImageGenerationResponse(images=images, usage=usage, raw=response) + except ProviderError: + raise + except Exception as exc: + raise _wrap_router_error(exc, provider_name=self.provider_name) from exc + return parse_chat_completion_response(response) - async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - image_kwargs = _collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) - if request.messages is not None: + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + try: response = await self._router.acompletion( model=request.model, messages=request.messages, - **image_kwargs, + **collect_non_none_optional_fields(request), + ) + except ProviderError: + raise + except Exception as exc: + raise _wrap_router_error(exc, provider_name=self.provider_name) from exc + return parse_chat_completion_response(response) + + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + try: + response = self._router.embedding( + model=request.model, + input=request.inputs, + **collect_non_none_optional_fields(request), ) - images = _extract_images_from_chat_response(response) - else: - response = await self._router.aimage_generation( - prompt=request.prompt, + except ProviderError: + raise + except Exception as exc: + raise _wrap_router_error(exc, provider_name=self.provider_name) from exc + vectors = [extract_embedding_vector(item) for item in getattr(response, "data", [])] + return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response) + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + try: + response = await self._router.aembedding( model=request.model, - **image_kwargs, + input=request.inputs, + **collect_non_none_optional_fields(request), ) - images = _extract_images_from_image_response(response) + except ProviderError: + raise + except Exception as exc: + raise _wrap_router_error(exc, provider_name=self.provider_name) from exc + vectors = [extract_embedding_vector(item) for item in getattr(response, "data", [])] + return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response) - usage = _extract_usage(getattr(response, "usage", None), generated_images=len(images)) + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) + try: + if request.messages is not None: + response = self._router.completion( + model=request.model, + messages=request.messages, + **image_kwargs, + ) + images = extract_images_from_chat_response(response) + else: + response = self._router.image_generation( + prompt=request.prompt, + model=request.model, + **image_kwargs, + ) + images = extract_images_from_image_response(response) + except ProviderError: + raise + except Exception as exc: + raise _wrap_router_error(exc, provider_name=self.provider_name) from exc + + usage = extract_usage(getattr(response, "usage", None), generated_images=len(images)) + return ImageGenerationResponse(images=images, usage=usage, raw=response) + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) + try: + if request.messages is not 
None: + response = await self._router.acompletion( + model=request.model, + messages=request.messages, + **image_kwargs, + ) + images = await aextract_images_from_chat_response(response) + else: + response = await self._router.aimage_generation( + prompt=request.prompt, + model=request.model, + **image_kwargs, + ) + images = await aextract_images_from_image_response(response) + except ProviderError: + raise + except Exception as exc: + raise _wrap_router_error(exc, provider_name=self.provider_name) from exc + + usage = extract_usage(getattr(response, "usage", None), generated_images=len(images)) return ImageGenerationResponse(images=images, usage=usage, raw=response) def close(self) -> None: @@ -148,219 +178,34 @@ async def aclose(self) -> None: return None -def _parse_chat_completion_response(response: Any) -> ChatCompletionResponse: - first_choice = _first_or_none(getattr(response, "choices", None)) - message = _value_from(first_choice, "message") - tool_calls = _extract_tool_calls(_value_from(message, "tool_calls")) - images = _extract_images_from_chat_message(message) - assistant_message = AssistantMessage( - content=_coerce_message_content(_value_from(message, "content")), - reasoning_content=_value_from(message, "reasoning_content"), - tool_calls=tool_calls, - images=images, - ) - usage = _extract_usage(getattr(response, "usage", None), generated_images=len(images) if images else None) - return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) - - -def _collect_non_none_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: - """Extract non-None optional fields from a request dataclass, skipping *exclude*.""" - return { - f.name: v - for f in dataclasses.fields(request) - if f.name not in exclude and f.default is None and (v := getattr(request, f.name)) is not None - } - - -def _extract_embedding_vector(item: Any) -> list[float]: - value = _value_from(item, "embedding") - if isinstance(value, list): - return [float(v) for v in value] - return [] - - -def _extract_tool_calls(raw_tool_calls: Any) -> list[ToolCall]: - if not raw_tool_calls: - return [] - - normalized_tool_calls: list[ToolCall] = [] - for raw_tool_call in raw_tool_calls: - tool_call_id = _value_from(raw_tool_call, "id") or "" - function = _value_from(raw_tool_call, "function") - name = _value_from(function, "name") or "" - arguments_value = _value_from(function, "arguments") - arguments_json = _serialize_tool_arguments(arguments_value) - normalized_tool_calls.append(ToolCall(id=str(tool_call_id), name=str(name), arguments_json=arguments_json)) - - return normalized_tool_calls - - -def _serialize_tool_arguments(arguments_value: Any) -> str: - if arguments_value is None: - return "{}" - if isinstance(arguments_value, str): - return arguments_value - try: - return json.dumps(arguments_value) - except Exception: - return str(arguments_value) - - -def _extract_images_from_chat_response(response: Any) -> list[ImagePayload]: - first_choice = _first_or_none(getattr(response, "choices", None)) - message = _value_from(first_choice, "message") - return _extract_images_from_chat_message(message) - - -def _extract_images_from_chat_message(message: Any) -> list[ImagePayload]: - images: list[ImagePayload] = [] - - raw_images = _value_from(message, "images") - if isinstance(raw_images, list): - for raw_image in raw_images: - parsed_image = _parse_image_payload(raw_image) - if parsed_image is not None: - images.append(parsed_image) - - if images: - return images - - 
raw_content = _value_from(message, "content") - if isinstance(raw_content, str): - parsed_image = _parse_image_payload(raw_content) - if parsed_image is not None: - images.append(parsed_image) - - return images - - -def _extract_images_from_image_response(response: Any) -> list[ImagePayload]: - images: list[ImagePayload] = [] - for raw_image in getattr(response, "data", []): - parsed_image = _parse_image_payload(raw_image) - if parsed_image is not None: - images.append(parsed_image) - return images - - -def _parse_image_payload(raw_image: Any) -> ImagePayload | None: - try: - if isinstance(raw_image, str): - return _parse_image_string(raw_image) - - if isinstance(raw_image, dict): - if "b64_json" in raw_image and isinstance(raw_image["b64_json"], str): - return ImagePayload(b64_data=raw_image["b64_json"], mime_type=None) - if "image_url" in raw_image: - return _parse_image_payload(raw_image["image_url"]) - if "url" in raw_image and isinstance(raw_image["url"], str): - return _parse_image_string(raw_image["url"]) - - b64_json = _value_from(raw_image, "b64_json") - if isinstance(b64_json, str): - return ImagePayload(b64_data=b64_json, mime_type=None) - - url = _value_from(raw_image, "url") - if isinstance(url, str): - return _parse_image_string(url) - except Exception: - logger.debug("Unable to parse image payload from bridge response object.", exc_info=True) - - return None - - -def _parse_image_string(raw_value: str) -> ImagePayload | None: - if raw_value.startswith("data:image/"): - return ImagePayload( - b64_data=extract_base64_from_data_uri(raw_value), - mime_type=_extract_mime_type_from_data_uri(raw_value), - ) - - if is_base64_image(raw_value): - return ImagePayload(b64_data=raw_value, mime_type=None) - - if raw_value.startswith(("http://", "https://")): - b64_data = load_image_url_to_base64(raw_value) - return ImagePayload(b64_data=b64_data, mime_type=None) - - return None - - -def _extract_mime_type_from_data_uri(data_uri: str) -> str | None: - if not data_uri.startswith("data:"): - return None - head = data_uri.split(",", maxsplit=1)[0] - if ";" in head: - return head[5:].split(";", maxsplit=1)[0] - return head[5:] or None - - -def _extract_usage(raw_usage: Any, generated_images: int | None = None) -> Usage | None: - if raw_usage is None and generated_images is None: - return None - - input_tokens = _value_from(raw_usage, "prompt_tokens") - output_tokens = _value_from(raw_usage, "completion_tokens") - total_tokens = _value_from(raw_usage, "total_tokens") - - if input_tokens is None: - input_tokens = _value_from(raw_usage, "input_tokens") - if output_tokens is None: - output_tokens = _value_from(raw_usage, "output_tokens") - - if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int): - total_tokens = input_tokens + output_tokens +def _wrap_router_error(exc: Exception, *, provider_name: str) -> ProviderError: + """Normalize a raw router/LiteLLM exception into a canonical ProviderError.""" + status_code = getattr(exc, "status_code", None) + if isinstance(status_code, int): + from data_designer.engine.models.clients.errors import map_http_status_to_provider_error_kind - if generated_images is None: - generated_images = _value_from(raw_usage, "generated_images") - if generated_images is None and raw_usage is not None: - generated_images = _value_from(raw_usage, "images") - - if input_tokens is None and output_tokens is None and total_tokens is None and generated_images is None: - return None + kind = 
map_http_status_to_provider_error_kind(status_code=status_code, body_text=str(exc)) + else: + kind = _infer_error_kind(exc) - return Usage( - input_tokens=_to_int_or_none(input_tokens), - output_tokens=_to_int_or_none(output_tokens), - total_tokens=_to_int_or_none(total_tokens), - generated_images=_to_int_or_none(generated_images), + return ProviderError( + kind=kind, + message=str(exc), + status_code=status_code if isinstance(status_code, int) else None, + provider_name=provider_name, + cause=exc, ) -def _coerce_message_content(content: Any) -> str | None: - if content is None: - return None - if isinstance(content, str): - return content - if isinstance(content, list): - text_parts: list[str] = [] - for block in content: - if isinstance(block, dict): - text_value = block.get("text") - if isinstance(text_value, str): - text_parts.append(text_value) - if text_parts: - return "\n".join(text_parts) - return str(content) - - -def _to_int_or_none(value: Any) -> int | None: - if isinstance(value, int): - return value - if isinstance(value, float): - return int(value) - return None - - -def _value_from(source: Any, key: str) -> Any: - if source is None: - return None - if isinstance(source, dict): - return source.get(key) - return getattr(source, key, None) - - -def _first_or_none(values: Any) -> Any | None: - if isinstance(values, list) and values: - return values[0] - return None +def _infer_error_kind(exc: Exception) -> ProviderErrorKind: + """Infer error kind from exception type name when no status code is available.""" + type_name = type(exc).__name__.lower() + if "timeout" in type_name: + return ProviderErrorKind.TIMEOUT + if "connection" in type_name or "connect" in type_name: + return ProviderErrorKind.API_CONNECTION + if "auth" in type_name: + return ProviderErrorKind.AUTHENTICATION + if "ratelimit" in type_name: + return ProviderErrorKind.RATE_LIMIT + return ProviderErrorKind.API_ERROR diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index 267226485..402fbd742 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -32,6 +32,7 @@ class ProviderError(Exception): status_code: int | None = None provider_name: str | None = None model_name: str | None = None + retry_after: float | None = None cause: Exception | None = None def __post_init__(self) -> None: @@ -103,12 +104,14 @@ def map_http_error_to_provider_error( ) kind = map_http_status_to_provider_error_kind(status_code=status_code, body_text=body_text) + retry_after = _extract_retry_after(response) if status_code == 429 else None return ProviderError( kind=kind, message=body_text or f"Provider {provider_name!r} request failed with status code {status_code}.", status_code=status_code, provider_name=provider_name, model_name=model_name, + retry_after=retry_after, ) @@ -144,6 +147,23 @@ def _extract_structured_message(response: HttpResponse) -> str: return "" +def _extract_retry_after(response: HttpResponse) -> float | None: + headers = getattr(response, "headers", None) + if headers is None: + return None + raw = ( + headers.get("retry-after") + if isinstance(headers, dict) + else getattr(headers, "get", lambda _: None)("retry-after") + ) + if raw is None: + return None + try: + return float(raw) + except (ValueError, TypeError): + return None + + def _looks_like_context_window_error(text: 
str) -> bool: return any( token in text diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py new file mode 100644 index 000000000..561240300 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -0,0 +1,318 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared response-parsing helpers reusable across provider adapters.""" + +from __future__ import annotations + +import dataclasses +import json +import logging +from typing import Any + +from data_designer.config.utils.image_helpers import ( + aload_image_url_to_base64, + extract_base64_from_data_uri, + is_base64_image, + load_image_url_to_base64, +) +from data_designer.engine.models.clients.types import ( + AssistantMessage, + ChatCompletionResponse, + ImagePayload, + ToolCall, + Usage, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# High-level response parsers +# --------------------------------------------------------------------------- + + +def parse_chat_completion_response(response: Any) -> ChatCompletionResponse: + first_choice = get_first_value_or_none(getattr(response, "choices", None)) + message = get_value_from(first_choice, "message") + tool_calls = extract_tool_calls(get_value_from(message, "tool_calls")) + images = extract_images_from_chat_message(message) + assistant_message = AssistantMessage( + content=coerce_message_content(get_value_from(message, "content")), + reasoning_content=get_value_from(message, "reasoning_content"), + tool_calls=tool_calls, + images=images, + ) + usage = extract_usage(getattr(response, "usage", None), generated_images=len(images) if images else None) + return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) + + +# --------------------------------------------------------------------------- +# Image extraction +# --------------------------------------------------------------------------- + + +def extract_images_from_chat_response(response: Any) -> list[ImagePayload]: + first_choice = get_first_value_or_none(getattr(response, "choices", None)) + message = get_value_from(first_choice, "message") + return extract_images_from_chat_message(message) + + +async def aextract_images_from_chat_response(response: Any) -> list[ImagePayload]: + first_choice = get_first_value_or_none(getattr(response, "choices", None)) + message = get_value_from(first_choice, "message") + return await aextract_images_from_chat_message(message) + + +def extract_images_from_chat_message(message: Any) -> list[ImagePayload]: + primary, fallback = collect_raw_image_candidates(message) + images = parse_image_list(primary) + return images if images else parse_image_list(fallback) + + +async def aextract_images_from_chat_message(message: Any) -> list[ImagePayload]: + primary, fallback = collect_raw_image_candidates(message) + images = await aparse_image_list(primary) + return images if images else await aparse_image_list(fallback) + + +def extract_images_from_image_response(response: Any) -> list[ImagePayload]: + return parse_image_list(getattr(response, "data", [])) + + +async def aextract_images_from_image_response(response: Any) -> list[ImagePayload]: + return await aparse_image_list(getattr(response, "data", [])) + + +def 
collect_raw_image_candidates(message: Any) -> tuple[list[Any], list[Any]]: + """Return (primary, fallback) raw image candidates from a message.""" + primary: list[Any] = [] + raw_images = get_value_from(message, "images") + if isinstance(raw_images, list): + primary = list(raw_images) + + fallback: list[Any] = [] + raw_content = get_value_from(message, "content") + if isinstance(raw_content, str): + fallback = [raw_content] + + return primary, fallback + + +def parse_image_list(raw_items: list[Any]) -> list[ImagePayload]: + return [img for raw in raw_items if (img := parse_image_payload(raw)) is not None] + + +async def aparse_image_list(raw_items: list[Any]) -> list[ImagePayload]: + return [img for raw in raw_items if (img := await aparse_image_payload(raw)) is not None] + + +# --------------------------------------------------------------------------- +# Image payload parsing +# --------------------------------------------------------------------------- + + +def parse_image_payload(raw_image: Any) -> ImagePayload | None: + try: + result = resolve_image_payload(raw_image) + if isinstance(result, str): + return ImagePayload(b64_data=load_image_url_to_base64(result), mime_type=None) + return result + except Exception: + logger.debug("Unable to parse image payload from response object.", exc_info=True) + return None + + +async def aparse_image_payload(raw_image: Any) -> ImagePayload | None: + try: + result = resolve_image_payload(raw_image) + if isinstance(result, str): + return ImagePayload(b64_data=await aload_image_url_to_base64(result), mime_type=None) + return result + except Exception: + logger.debug("Unable to parse image payload from response object.", exc_info=True) + return None + + +def resolve_image_payload(raw_image: Any) -> ImagePayload | str | None: + """Resolve a raw image to an ImagePayload, a URL needing I/O, or None.""" + if isinstance(raw_image, str): + return resolve_image_string(raw_image) + + if isinstance(raw_image, dict): + if "b64_json" in raw_image and isinstance(raw_image["b64_json"], str): + return ImagePayload(b64_data=raw_image["b64_json"], mime_type=None) + if "image_url" in raw_image: + return resolve_image_payload(raw_image["image_url"]) + if "url" in raw_image and isinstance(raw_image["url"], str): + return resolve_image_string(raw_image["url"]) + + b64_json = get_value_from(raw_image, "b64_json") + if isinstance(b64_json, str): + return ImagePayload(b64_data=b64_json, mime_type=None) + + url = get_value_from(raw_image, "url") + if isinstance(url, str): + return resolve_image_string(url) + + return None + + +def resolve_image_string(raw_value: str) -> ImagePayload | str | None: + """Return an ImagePayload for inline data, a URL string for HTTP URLs, or None.""" + if raw_value.startswith("data:image/"): + return ImagePayload( + b64_data=extract_base64_from_data_uri(raw_value), + mime_type=extract_mime_type_from_data_uri(raw_value), + ) + + if is_base64_image(raw_value): + return ImagePayload(b64_data=raw_value, mime_type=None) + + if raw_value.startswith(("http://", "https://")): + return raw_value + + return None + + +# --------------------------------------------------------------------------- +# Tool call parsing +# --------------------------------------------------------------------------- + + +def extract_tool_calls(raw_tool_calls: Any) -> list[ToolCall]: + if not raw_tool_calls: + return [] + + normalized_tool_calls: list[ToolCall] = [] + for raw_tool_call in raw_tool_calls: + tool_call_id = get_value_from(raw_tool_call, "id") or "" + function = 
get_value_from(raw_tool_call, "function") + name = get_value_from(function, "name") or "" + arguments_value = get_value_from(function, "arguments") + arguments_json = serialize_tool_arguments(arguments_value) + normalized_tool_calls.append(ToolCall(id=str(tool_call_id), name=str(name), arguments_json=arguments_json)) + + return normalized_tool_calls + + +def serialize_tool_arguments(arguments_value: Any) -> str: + if arguments_value is None: + return "{}" + if isinstance(arguments_value, str): + return arguments_value + try: + return json.dumps(arguments_value) + except Exception: + return str(arguments_value) + + +# --------------------------------------------------------------------------- +# Usage & content helpers +# --------------------------------------------------------------------------- + + +def extract_usage(raw_usage: Any, generated_images: int | None = None) -> Usage | None: + if raw_usage is None and generated_images is None: + return None + + input_tokens = get_value_from(raw_usage, "prompt_tokens") + output_tokens = get_value_from(raw_usage, "completion_tokens") + total_tokens = get_value_from(raw_usage, "total_tokens") + + if input_tokens is None: + input_tokens = get_value_from(raw_usage, "input_tokens") + if output_tokens is None: + output_tokens = get_value_from(raw_usage, "output_tokens") + + if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int): + total_tokens = input_tokens + output_tokens + + if generated_images is None: + generated_images = get_value_from(raw_usage, "generated_images") + if generated_images is None and raw_usage is not None: + generated_images = get_value_from(raw_usage, "images") + + if input_tokens is None and output_tokens is None and total_tokens is None and generated_images is None: + return None + + return Usage( + input_tokens=coerce_to_int_or_none(input_tokens), + output_tokens=coerce_to_int_or_none(output_tokens), + total_tokens=coerce_to_int_or_none(total_tokens), + generated_images=coerce_to_int_or_none(generated_images), + ) + + +def extract_embedding_vector(item: Any) -> list[float]: + value = get_value_from(item, "embedding") + if isinstance(value, list): + return [float(v) for v in value] + return [] + + +def extract_mime_type_from_data_uri(data_uri: str) -> str | None: + if not data_uri.startswith("data:"): + return None + head = data_uri.split(",", maxsplit=1)[0] + if ";" in head: + return head[5:].split(";", maxsplit=1)[0] + return head[5:] or None + + +def coerce_message_content(content: Any) -> str | None: + if content is None: + return None + if isinstance(content, str): + return content + if isinstance(content, list): + text_parts: list[str] = [] + for block in content: + if isinstance(block, dict): + text_value = block.get("text") + if isinstance(text_value, str): + text_parts.append(text_value) + if text_parts: + return "\n".join(text_parts) + return str(content) + + +def coerce_to_int_or_none(value: Any) -> int | None: + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + if isinstance(value, str): + try: + return int(float(value)) + except (ValueError, OverflowError): + return None + return None + + +# --------------------------------------------------------------------------- +# Generic accessors +# --------------------------------------------------------------------------- + + +def get_value_from(source: Any, key: str) -> Any: + if source is None: + return None + if isinstance(source, dict): + return source.get(key) + return getattr(source, 
key, None) + + +def get_first_value_or_none(values: Any) -> Any | None: + if isinstance(values, list) and values: + return values[0] + return None + + +def collect_non_none_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: + """Extract non-None optional fields from a request dataclass, skipping *exclude*.""" + return { + f.name: v + for f in dataclasses.fields(request) + if f.name not in exclude and f.default is None and (v := getattr(request, f.name)) is not None + } diff --git a/packages/data-designer-engine/tests/engine/models/clients/conftest.py b/packages/data-designer-engine/tests/engine/models/clients/conftest.py new file mode 100644 index 000000000..6bfb439f2 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/conftest.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient + + +@pytest.fixture +def mock_router() -> MagicMock: + return MagicMock() + + +@pytest.fixture +def bridge_client(mock_router: MagicMock) -> LiteLLMBridgeClient: + return LiteLLMBridgeClient(provider_name="stub-provider", router=mock_router) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py index 7ca75f9d6..564eca545 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py @@ -16,10 +16,18 @@ class StubHttpResponse: - def __init__(self, *, status_code: int, text: str = "", json_payload: dict[str, Any] | None = None) -> None: + def __init__( + self, + *, + status_code: int, + text: str = "", + json_payload: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ) -> None: self.status_code = status_code self.text = text self._json_payload = json_payload + self.headers = headers or {} def json(self) -> dict[str, Any]: if self._json_payload is None: @@ -52,58 +60,58 @@ def test_map_http_status_to_provider_error_kind( assert map_http_status_to_provider_error_kind(status_code=status_code, body_text=body_text) == expected_kind -def test_map_http_error_to_provider_error_uses_text_when_no_json() -> None: - response = StubHttpResponse(status_code=429, text="Rate limit hit") +@pytest.mark.parametrize( + "status_code,text,json_payload,expected_kind,expected_message", + [ + ( + 429, + "Rate limit hit", + None, + ProviderErrorKind.RATE_LIMIT, + "Rate limit hit", + ), + ( + 400, + '{"error": {"type": "invalid_request_error", "message": "Context too long."}}', + {"error": {"type": "invalid_request_error", "message": "Context too long."}}, + ProviderErrorKind.BAD_REQUEST, + "Context too long.", + ), + ( + 403, + "", + {"error": "Insufficient permissions for model"}, + ProviderErrorKind.PERMISSION_DENIED, + "Insufficient permissions for model", + ), + ( + 400, + "", + {"error": {"type": "invalid_request_error", "message": "The request payload is invalid."}}, + ProviderErrorKind.BAD_REQUEST, + "The request payload is invalid.", + ), + ], + ids=[ + "text-when-no-json", + "json-over-raw-text", + "json-when-text-missing", + "nested-error-message", + ], +) +def test_map_http_error_to_provider_error( + status_code: int, + 
text: str, + json_payload: dict[str, Any] | None, + expected_kind: ProviderErrorKind, + expected_message: str, +) -> None: + response = StubHttpResponse(status_code=status_code, text=text, json_payload=json_payload) error = map_http_error_to_provider_error(response=response, provider_name="stub-provider", model_name="stub-model") assert isinstance(error, ProviderError) - assert error.kind == ProviderErrorKind.RATE_LIMIT - assert error.message == "Rate limit hit" - assert error.status_code == 429 + assert error.kind == expected_kind + assert error.message == expected_message assert error.provider_name == "stub-provider" - assert error.model_name == "stub-model" - - -def test_map_http_error_to_provider_error_prefers_json_over_raw_text() -> None: - response = StubHttpResponse( - status_code=400, - text='{"error": {"type": "invalid_request_error", "message": "Context too long."}}', - json_payload={ - "error": { - "type": "invalid_request_error", - "message": "Context too long.", - } - }, - ) - error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") - assert error.kind == ProviderErrorKind.BAD_REQUEST - assert error.message == "Context too long." - - -def test_map_http_error_to_provider_error_uses_json_payload_when_text_missing() -> None: - response = StubHttpResponse( - status_code=403, - text="", - json_payload={"error": "Insufficient permissions for model"}, - ) - error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") - assert error.kind == ProviderErrorKind.PERMISSION_DENIED - assert error.message == "Insufficient permissions for model" - - -def test_map_http_error_to_provider_error_uses_nested_error_message_payload() -> None: - response = StubHttpResponse( - status_code=400, - text="", - json_payload={ - "error": { - "type": "invalid_request_error", - "message": "The request payload is invalid.", - } - }, - ) - error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") - assert error.kind == ProviderErrorKind.BAD_REQUEST - assert error.message == "The request payload is invalid." 
def test_provider_error_unsupported_capability_helper() -> None: @@ -145,3 +153,23 @@ def test_provider_error_is_catchable_as_exception() -> None: raise error assert exc_info.value.kind == ProviderErrorKind.AUTHENTICATION assert str(exc_info.value) == "Invalid API key" + + +def test_map_http_error_extracts_retry_after_on_429() -> None: + response = StubHttpResponse( + status_code=429, + text="Rate limit hit", + headers={"retry-after": "2.5"}, + ) + error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") + assert error.retry_after == 2.5 + + +def test_map_http_error_retry_after_is_none_for_non_429() -> None: + response = StubHttpResponse( + status_code=500, + text="Internal server error", + headers={"retry-after": "10"}, + ) + error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") + assert error.retry_after is None diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py index 24df83b56..7c9b8db9a 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py @@ -10,6 +10,7 @@ import pytest from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind from data_designer.engine.models.clients.types import ( ChatCompletionRequest, EmbeddingRequest, @@ -17,16 +18,17 @@ ) -def test_completion_maps_canonical_fields_from_litellm_response() -> None: +def test_completion_maps_canonical_fields_from_litellm_response( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: response = _build_chat_response( content="final answer", reasoning_content="reasoning trace", tool_calls=[{"id": "call-1", "function": {"name": "lookup", "arguments": '{"query":"foo"}'}}], usage=SimpleNamespace(prompt_tokens=11, completion_tokens=13, total_tokens=24), ) - router = MagicMock() - router.completion.return_value = response - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.completion.return_value = response request = ChatCompletionRequest( model="stub-model", @@ -39,7 +41,7 @@ def test_completion_maps_canonical_fields_from_litellm_response() -> None: extra_headers={"x-trace": "1"}, metadata={"trace_id": "abc"}, ) - result = client.completion(request) + result = bridge_client.completion(request) assert result.message.content == "final answer" assert result.message.reasoning_content == "reasoning trace" @@ -53,7 +55,7 @@ def test_completion_maps_canonical_fields_from_litellm_response() -> None: assert result.usage.total_tokens == 24 assert result.raw is response - router.completion.assert_called_once_with( + mock_router.completion.assert_called_once_with( model="stub-model", messages=[{"role": "user", "content": "hello"}], tools=[{"type": "function", "function": {"name": "lookup"}}], @@ -67,41 +69,43 @@ def test_completion_maps_canonical_fields_from_litellm_response() -> None: @pytest.mark.asyncio -async def test_acompletion_maps_canonical_fields_from_litellm_response() -> None: +async def test_acompletion_maps_canonical_fields_from_litellm_response( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: response = _build_chat_response(content="async result", reasoning_content=None, tool_calls=[], usage=None) - router = 
MagicMock() - router.acompletion = AsyncMock(return_value=response) - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.acompletion = AsyncMock(return_value=response) request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) - result = await client.acompletion(request) + result = await bridge_client.acompletion(request) assert result.message.content == "async result" assert result.usage is None - router.acompletion.assert_awaited_once_with( + mock_router.acompletion.assert_awaited_once_with( model="stub-model", messages=[{"role": "user", "content": "hello"}], ) -def test_embeddings_maps_vectors_and_usage() -> None: +def test_embeddings_maps_vectors_and_usage( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: response = SimpleNamespace( data=[{"embedding": [1, 2]}, SimpleNamespace(embedding=[3.5, 4.5])], usage=SimpleNamespace(prompt_tokens=4, total_tokens=4), ) - router = MagicMock() - router.embedding.return_value = response - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.embedding.return_value = response request = EmbeddingRequest(model="stub-model", inputs=["a", "b"], dimensions=32, encoding_format="float") - result = client.embeddings(request) + result = bridge_client.embeddings(request) assert result.vectors == [[1.0, 2.0], [3.5, 4.5]] assert result.usage is not None assert result.usage.input_tokens == 4 assert result.usage.output_tokens is None assert result.raw is response - router.embedding.assert_called_once_with( + mock_router.embedding.assert_called_once_with( model="stub-model", input=["a", "b"], encoding_format="float", @@ -109,7 +113,19 @@ def test_embeddings_maps_vectors_and_usage() -> None: ) -def test_generate_image_uses_chat_completion_path_when_messages_are_present() -> None: +@pytest.mark.parametrize( + "messages", + [ + [{"role": "user", "content": "generate image"}], + [], + ], + ids=["with-content", "empty-list"], +) +def test_generate_image_uses_chat_completion_path_when_messages_provided( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, + messages: list[dict[str, Any]], +) -> None: response = _build_chat_response( content=None, reasoning_content=None, @@ -117,61 +133,30 @@ def test_generate_image_uses_chat_completion_path_when_messages_are_present() -> images=[{"image_url": {"url": "data:image/png;base64,aGVsbG8="}}], usage=None, ) - router = MagicMock() - router.completion.return_value = response - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.completion.return_value = response request = ImageGenerationRequest( model="stub-model", prompt="unused because messages are supplied", - messages=[{"role": "user", "content": "generate image"}], + messages=messages, n=1, ) - result = client.generate_image(request) + result = bridge_client.generate_image(request) assert len(result.images) == 1 assert result.images[0].b64_data == "aGVsbG8=" - assert result.images[0].mime_type == "image/png" - assert result.usage is not None - assert result.usage.generated_images == 1 - router.completion.assert_called_once_with( + mock_router.completion.assert_called_once_with( model="stub-model", - messages=[{"role": "user", "content": "generate image"}], + messages=messages, n=1, ) - router.image_generation.assert_not_called() + mock_router.image_generation.assert_not_called() -def test_generate_image_uses_chat_completion_path_when_messages_is_empty_list() -> None: - response = 
_build_chat_response( - content=None, - reasoning_content=None, - tool_calls=None, - images=[{"image_url": {"url": "data:image/png;base64,aGVsbG8="}}], - usage=None, - ) - router = MagicMock() - router.completion.return_value = response - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) - - request = ImageGenerationRequest( - model="stub-model", - prompt="unused because messages are supplied", - messages=[], - n=1, - ) - result = client.generate_image(request) - - assert len(result.images) == 1 - router.completion.assert_called_once_with( - model="stub-model", - messages=[], - n=1, - ) - router.image_generation.assert_not_called() - - -def test_generate_image_uses_diffusion_path_without_messages() -> None: +def test_generate_image_uses_diffusion_path_without_messages( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: response = SimpleNamespace( data=[ SimpleNamespace(b64_json="Zmlyc3Q="), @@ -179,12 +164,10 @@ def test_generate_image_uses_diffusion_path_without_messages() -> None: ], usage=SimpleNamespace(input_tokens=9, output_tokens=12), ) - router = MagicMock() - router.image_generation.return_value = response - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.image_generation.return_value = response request = ImageGenerationRequest(model="stub-model", prompt="make an image", n=2) - result = client.generate_image(request) + result = bridge_client.generate_image(request) assert [image.b64_data for image in result.images] == ["Zmlyc3Q=", "c2Vjb25k"] assert [image.mime_type for image in result.images] == [None, "image/jpeg"] @@ -193,111 +176,235 @@ def test_generate_image_uses_diffusion_path_without_messages() -> None: assert result.usage.output_tokens == 12 assert result.usage.total_tokens == 21 assert result.usage.generated_images == 2 - router.image_generation.assert_called_once_with(prompt="make an image", model="stub-model", n=2) + mock_router.image_generation.assert_called_once_with(prompt="make an image", model="stub-model", n=2) @pytest.mark.asyncio -async def test_aembeddings_maps_vectors_and_usage() -> None: +async def test_aembeddings_maps_vectors_and_usage( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: response = SimpleNamespace( data=[{"embedding": [0.1, 0.2]}, SimpleNamespace(embedding=[0.3, 0.4])], usage=SimpleNamespace(prompt_tokens=5, total_tokens=5), ) - router = MagicMock() - router.aembedding = AsyncMock(return_value=response) - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.aembedding = AsyncMock(return_value=response) request = EmbeddingRequest(model="stub-model", inputs=["x", "y"]) - result = await client.aembeddings(request) + result = await bridge_client.aembeddings(request) assert result.vectors == [[0.1, 0.2], [0.3, 0.4]] assert result.usage is not None assert result.usage.input_tokens == 5 assert result.raw is response - router.aembedding.assert_awaited_once_with(model="stub-model", input=["x", "y"]) + mock_router.aembedding.assert_awaited_once_with(model="stub-model", input=["x", "y"]) -def test_completion_coerces_list_content_blocks_to_string() -> None: +def test_completion_coerces_list_content_blocks_to_string( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: response = _build_chat_response( content=[{"type": "text", "text": "first"}, {"type": "text", "text": "second"}], reasoning_content=None, tool_calls=[], usage=None, ) - router = MagicMock() - 
router.completion.return_value = response - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.completion.return_value = response request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) - result = client.completion(request) + result = bridge_client.completion(request) assert result.message.content == "first\nsecond" -def test_close_and_aclose_are_callable() -> None: - router = MagicMock() - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) - client.close() +def test_close_and_aclose_are_callable(bridge_client: LiteLLMBridgeClient) -> None: + bridge_client.close() @pytest.mark.asyncio -async def test_aclose_is_callable() -> None: - router = MagicMock() - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) - await client.aclose() +async def test_aclose_is_callable(bridge_client: LiteLLMBridgeClient) -> None: + await bridge_client.aclose() @pytest.mark.asyncio -async def test_agenerate_image_uses_diffusion_path_without_messages() -> None: +async def test_agenerate_image_uses_diffusion_path_without_messages( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: response = SimpleNamespace( data=[SimpleNamespace(b64_json="YXN5bmM=")], usage=SimpleNamespace(input_tokens=3, output_tokens=7), ) - router = MagicMock() - router.aimage_generation = AsyncMock(return_value=response) - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.aimage_generation = AsyncMock(return_value=response) request = ImageGenerationRequest(model="stub-model", prompt="async image", n=1) - result = await client.agenerate_image(request) + result = await bridge_client.agenerate_image(request) assert len(result.images) == 1 assert result.images[0].b64_data == "YXN5bmM=" assert result.usage is not None assert result.usage.generated_images == 1 - router.aimage_generation.assert_awaited_once_with(prompt="async image", model="stub-model", n=1) + mock_router.aimage_generation.assert_awaited_once_with(prompt="async image", model="stub-model", n=1) -def test_completion_with_empty_choices_returns_empty_message() -> None: +def test_completion_with_empty_choices_returns_empty_message( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: response = SimpleNamespace(choices=[], usage=None) - router = MagicMock() - router.completion.return_value = response - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.completion.return_value = response request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) - result = client.completion(request) + result = bridge_client.completion(request) assert result.message.content is None assert result.message.tool_calls == [] assert result.message.images == [] -def test_completion_with_tool_call_dict_arguments() -> None: +def test_completion_with_tool_call_dict_arguments( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: response = _build_chat_response( content=None, reasoning_content=None, tool_calls=[{"id": "call-2", "function": {"name": "search", "arguments": {"q": "test"}}}], usage=None, ) - router = MagicMock() - router.completion.return_value = response - client = LiteLLMBridgeClient(provider_name="stub-provider", router=router) + mock_router.completion.return_value = response request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) - result = 
client.completion(request) + result = bridge_client.completion(request) assert len(result.message.tool_calls) == 1 assert result.message.tool_calls[0].arguments_json == '{"q": "test"}' +# --- Exception wrapping tests --- + + +def test_completion_wraps_router_exception_with_status_code( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: + exc = Exception("Rate limit exceeded") + exc.status_code = 429 # type: ignore[attr-defined] + mock_router.completion.side_effect = exc + + request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) + with pytest.raises(ProviderError) as exc_info: + bridge_client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.RATE_LIMIT + assert exc_info.value.status_code == 429 + assert exc_info.value.provider_name == "stub-provider" + assert exc_info.value.cause is exc + + +def test_completion_wraps_generic_router_exception( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: + mock_router.completion.side_effect = RuntimeError("something broke") + + request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) + with pytest.raises(ProviderError) as exc_info: + bridge_client.completion(request) + + assert exc_info.value.kind == ProviderErrorKind.API_ERROR + assert "something broke" in exc_info.value.message + assert exc_info.value.status_code is None + + +def test_completion_passes_through_provider_error( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: + original = ProviderError(kind=ProviderErrorKind.AUTHENTICATION, message="bad key") + mock_router.completion.side_effect = original + + request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) + with pytest.raises(ProviderError) as exc_info: + bridge_client.completion(request) + + assert exc_info.value is original + + +@pytest.mark.asyncio +async def test_acompletion_wraps_router_exception( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: + mock_router.acompletion = AsyncMock(side_effect=ConnectionError("connection refused")) + + request = ChatCompletionRequest(model="stub-model", messages=[{"role": "user", "content": "hello"}]) + with pytest.raises(ProviderError) as exc_info: + await bridge_client.acompletion(request) + + assert exc_info.value.kind == ProviderErrorKind.API_CONNECTION + + +def test_embeddings_wraps_router_exception( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: + exc = Exception("server error") + exc.status_code = 500 # type: ignore[attr-defined] + mock_router.embedding.side_effect = exc + + request = EmbeddingRequest(model="stub-model", inputs=["a"]) + with pytest.raises(ProviderError) as exc_info: + bridge_client.embeddings(request) + + assert exc_info.value.kind == ProviderErrorKind.INTERNAL_SERVER + + +def test_generate_image_wraps_router_exception( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: + mock_router.image_generation.side_effect = TimeoutError("timed out") + + request = ImageGenerationRequest(model="stub-model", prompt="make an image") + with pytest.raises(ProviderError) as exc_info: + bridge_client.generate_image(request) + + assert exc_info.value.kind == ProviderErrorKind.TIMEOUT + + +@pytest.mark.asyncio +async def test_agenerate_image_wraps_router_exception( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: + mock_router.aimage_generation = 
AsyncMock(side_effect=RuntimeError("boom")) + + request = ImageGenerationRequest(model="stub-model", prompt="async image") + with pytest.raises(ProviderError) as exc_info: + await bridge_client.agenerate_image(request) + + assert exc_info.value.kind == ProviderErrorKind.API_ERROR + + +@pytest.mark.asyncio +async def test_aembeddings_wraps_router_exception( + mock_router: MagicMock, + bridge_client: LiteLLMBridgeClient, +) -> None: + mock_router.aembedding = AsyncMock(side_effect=RuntimeError("network error")) + + request = EmbeddingRequest(model="stub-model", inputs=["a"]) + with pytest.raises(ProviderError) as exc_info: + await bridge_client.aembeddings(request) + + assert exc_info.value.kind == ProviderErrorKind.API_ERROR + + +# --- Helpers --- + + def _build_chat_response( *, content: Any, From ba22397bd3030dc372a273cccb43c4d4a279e961 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 4 Mar 2026 10:07:23 -0700 Subject: [PATCH 15/27] Use contextlib to dry out some code --- .../models/clients/adapters/litellm_bridge.py | 82 ++++++++----------- 1 file changed, 34 insertions(+), 48 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py index 09194aa4b..38b414813 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -3,11 +3,17 @@ from __future__ import annotations +import contextlib import logging +from collections.abc import Iterator from typing import Any, Protocol from data_designer.engine.models.clients.base import ModelClient -from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind +from data_designer.engine.models.clients.errors import ( + ProviderError, + ProviderErrorKind, + map_http_status_to_provider_error_kind, +) from data_designer.engine.models.clients.parsing import ( aextract_images_from_chat_response, aextract_images_from_image_response, @@ -68,62 +74,46 @@ def supports_image_generation(self) -> bool: return True def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - try: + with _handle_non_provider_errors(self.provider_name): response = self._router.completion( model=request.model, messages=request.messages, **collect_non_none_optional_fields(request), ) - except ProviderError: - raise - except Exception as exc: - raise _wrap_router_error(exc, provider_name=self.provider_name) from exc return parse_chat_completion_response(response) async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - try: + with _handle_non_provider_errors(self.provider_name): response = await self._router.acompletion( model=request.model, messages=request.messages, **collect_non_none_optional_fields(request), ) - except ProviderError: - raise - except Exception as exc: - raise _wrap_router_error(exc, provider_name=self.provider_name) from exc return parse_chat_completion_response(response) def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - try: + with _handle_non_provider_errors(self.provider_name): response = self._router.embedding( model=request.model, input=request.inputs, **collect_non_none_optional_fields(request), ) - except ProviderError: - raise - except Exception as exc: - raise _wrap_router_error(exc, provider_name=self.provider_name) from exc vectors = 
[extract_embedding_vector(item) for item in getattr(response, "data", [])] return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response) async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - try: + with _handle_non_provider_errors(self.provider_name): response = await self._router.aembedding( model=request.model, input=request.inputs, **collect_non_none_optional_fields(request), ) - except ProviderError: - raise - except Exception as exc: - raise _wrap_router_error(exc, provider_name=self.provider_name) from exc vectors = [extract_embedding_vector(item) for item in getattr(response, "data", [])] return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response) def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) - try: + with _handle_non_provider_errors(self.provider_name): if request.messages is not None: response = self._router.completion( model=request.model, @@ -138,17 +128,13 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp **image_kwargs, ) images = extract_images_from_image_response(response) - except ProviderError: - raise - except Exception as exc: - raise _wrap_router_error(exc, provider_name=self.provider_name) from exc usage = extract_usage(getattr(response, "usage", None), generated_images=len(images)) return ImageGenerationResponse(images=images, usage=usage, raw=response) async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) - try: + with _handle_non_provider_errors(self.provider_name): if request.messages is not None: response = await self._router.acompletion( model=request.model, @@ -163,10 +149,6 @@ async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerat **image_kwargs, ) images = await aextract_images_from_image_response(response) - except ProviderError: - raise - except Exception as exc: - raise _wrap_router_error(exc, provider_name=self.provider_name) from exc usage = extract_usage(getattr(response, "usage", None), generated_images=len(images)) return ImageGenerationResponse(images=images, usage=usage, raw=response) @@ -178,23 +160,27 @@ async def aclose(self) -> None: return None -def _wrap_router_error(exc: Exception, *, provider_name: str) -> ProviderError: - """Normalize a raw router/LiteLLM exception into a canonical ProviderError.""" - status_code = getattr(exc, "status_code", None) - if isinstance(status_code, int): - from data_designer.engine.models.clients.errors import map_http_status_to_provider_error_kind - - kind = map_http_status_to_provider_error_kind(status_code=status_code, body_text=str(exc)) - else: - kind = _infer_error_kind(exc) - - return ProviderError( - kind=kind, - message=str(exc), - status_code=status_code if isinstance(status_code, int) else None, - provider_name=provider_name, - cause=exc, - ) +@contextlib.contextmanager +def _handle_non_provider_errors(provider_name: str) -> Iterator[None]: + """Catch non-ProviderError exceptions from the router and re-raise as ProviderError.""" + try: + yield + except ProviderError: + raise + except Exception as exc: + status_code = getattr(exc, "status_code", None) + if isinstance(status_code, int): + kind = map_http_status_to_provider_error_kind(status_code=status_code, 
body_text=str(exc)) + else: + kind = _infer_error_kind(exc) + + raise ProviderError( + kind=kind, + message=str(exc), + status_code=status_code if isinstance(status_code, int) else None, + provider_name=provider_name, + cause=exc, + ) from exc def _infer_error_kind(exc: Exception) -> ProviderErrorKind: From aeac3b9b0d0e76a801a94a51502a6f689c433c71 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 4 Mar 2026 10:16:00 -0700 Subject: [PATCH 16/27] Address Greptile feedback: HTTP-date retry-after parsing, docstring clarity - Parse RFC 7231 HTTP-date strings in Retry-After header (used by Azure and Anthropic during rate-limiting) in addition to numeric delay-seconds - Clarify collect_non_none_optional_fields docstring explaining why f.default is None is the correct check for optional field forwarding - Add tests for HTTP-date and garbage Retry-After values --- .../engine/models/clients/errors.py | 15 +++++++++++++ .../engine/models/clients/parsing.py | 7 ++++++- .../models/clients/test_client_errors.py | 21 +++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index 402fbd742..15108734e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -3,6 +3,9 @@ from __future__ import annotations +import calendar +import email.utils +import time from dataclasses import dataclass from enum import Enum @@ -148,6 +151,7 @@ def _extract_structured_message(response: HttpResponse) -> str: def _extract_retry_after(response: HttpResponse) -> float | None: + """Parse Retry-After header value (delay-seconds or HTTP-date per RFC 7231).""" headers = getattr(response, "headers", None) if headers is None: return None @@ -161,7 +165,18 @@ def _extract_retry_after(response: HttpResponse) -> float | None: try: return float(raw) except (ValueError, TypeError): + pass + return _parse_http_date_as_delay(raw) + + +def _parse_http_date_as_delay(value: str) -> float | None: + """Convert an HTTP-date Retry-After value to seconds from now.""" + parsed = email.utils.parsedate(value) + if parsed is None: return None + target = calendar.timegm(parsed) + delay = target - time.time() + return max(delay, 0.0) def _looks_like_context_window_error(text: str) -> bool: diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index 561240300..b3123e091 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -310,7 +310,12 @@ def get_first_value_or_none(values: Any) -> Any | None: def collect_non_none_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: - """Extract non-None optional fields from a request dataclass, skipping *exclude*.""" + """Extract non-None optional fields from a request dataclass, skipping *exclude*. + + The ``f.default is None`` check intentionally targets fields whose default is + ``None`` — i.e. truly optional kwargs the caller may or may not set. Fields with + non-``None`` defaults are not "optional" in this forwarding sense and are excluded. 
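+
+    For example, given a request dataclass with fields ``model: str`` (required,
+    no default), ``top_p: float | None = None``, and ``n: int = 1``, only
+    ``top_p`` is eligible for forwarding here: ``model`` has no default at all
+    and ``n`` has a non-``None`` default. (Illustrative field names, not a real
+    request type.)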
+ """ return { f.name: v for f in dataclasses.fields(request) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py index 564eca545..31039120d 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py @@ -173,3 +173,24 @@ def test_map_http_error_retry_after_is_none_for_non_429() -> None: ) error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") assert error.retry_after is None + + +def test_map_http_error_extracts_retry_after_from_http_date() -> None: + response = StubHttpResponse( + status_code=429, + text="Rate limit hit", + headers={"retry-after": "Fri, 31 Dec 2027 23:59:59 GMT"}, + ) + error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") + assert error.retry_after is not None + assert error.retry_after > 0 + + +def test_map_http_error_retry_after_returns_none_for_garbage() -> None: + response = StubHttpResponse( + status_code=429, + text="Rate limit hit", + headers={"retry-after": "not-a-date-or-number"}, + ) + error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") + assert error.retry_after is None From 55f3c960614170d8fc542cac5cb561f72a29128d Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 4 Mar 2026 13:27:24 -0700 Subject: [PATCH 17/27] Address Greptile feedback: FastAPI detail parsing, comment fixes - Fix misleading comment about prompt field defaults in _IMAGE_EXCLUDE - Handle list-format detail arrays in _extract_structured_message for FastAPI/Pydantic validation errors - Document scope boundary for vision content in collect_raw_image_candidates --- .../models/clients/adapters/litellm_bridge.py | 6 +++--- .../data_designer/engine/models/clients/errors.py | 6 ++++++ .../data_designer/engine/models/clients/parsing.py | 7 ++++++- .../engine/models/clients/test_client_errors.py | 13 +++++++++++++ 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py index 38b414813..45d615031 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -55,9 +55,9 @@ async def aimage_generation(self, *, prompt: str, model: str, **kwargs: Any) -> class LiteLLMBridgeClient(ModelClient): """Bridge adapter that wraps the existing LiteLLM router behind canonical client types.""" - # "messages" and "prompt" have None defaults but are passed explicitly to choose - # between the chat-completion and diffusion code paths, so exclude them from the - # automatic optional-field forwarding. + # "messages" (optional, default None) and "prompt" (required) are passed explicitly + # to choose between the chat-completion and diffusion code paths, so exclude them + # from the automatic optional-field forwarding. 
_IMAGE_EXCLUDE = frozenset({"messages", "prompt"}) def __init__(self, *, provider_name: str, router: LiteLLMRouter) -> None: diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index 15108734e..8e8a3b0ac 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -147,6 +147,12 @@ def _extract_structured_message(response: HttpResponse) -> str: nested_message = value.get("message") if isinstance(nested_message, str) and nested_message.strip(): return nested_message.strip() + if isinstance(value, list): + parts = [ + item.get("msg") for item in value if isinstance(item, dict) and isinstance(item.get("msg"), str) + ] + if parts: + return "; ".join(parts) return "" diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index b3123e091..315d5d4e7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -85,7 +85,12 @@ async def aextract_images_from_image_response(response: Any) -> list[ImagePayloa def collect_raw_image_candidates(message: Any) -> tuple[list[Any], list[Any]]: - """Return (primary, fallback) raw image candidates from a message.""" + """Return (primary, fallback) raw image candidates from a message. + + Only string content is used as a fallback source. List-format content blocks + (e.g. OpenAI multimodal ``image_url`` items) are not extracted here; that + parsing is deferred to adapter-specific logic in future PRs. 
+ """ primary: list[Any] = [] raw_images = get_value_from(message, "images") if isinstance(raw_images, list): diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py index 31039120d..194884f13 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py @@ -91,12 +91,25 @@ def test_map_http_status_to_provider_error_kind( ProviderErrorKind.BAD_REQUEST, "The request payload is invalid.", ), + ( + 422, + "", + { + "detail": [ + {"loc": ["body", "name"], "msg": "field required"}, + {"loc": ["body", "age"], "msg": "not a valid integer"}, + ] + }, + ProviderErrorKind.UNPROCESSABLE_ENTITY, + "field required; not a valid integer", + ), ], ids=[ "text-when-no-json", "json-over-raw-text", "json-when-text-missing", "nested-error-message", + "fastapi-list-detail", ], ) def test_map_http_error_to_provider_error( From 828cc49b428610f17131ee677b7aad4f79c69e85 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Wed, 4 Mar 2026 15:02:22 -0700 Subject: [PATCH 18/27] add PR-2 architecture notes for model facade overhaul --- ...facade-overhaul-pr-2-architecture-notes.md | 140 ++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 plans/343/model-facade-overhaul-pr-2-architecture-notes.md diff --git a/plans/343/model-facade-overhaul-pr-2-architecture-notes.md b/plans/343/model-facade-overhaul-pr-2-architecture-notes.md new file mode 100644 index 000000000..5a07a9549 --- /dev/null +++ b/plans/343/model-facade-overhaul-pr-2-architecture-notes.md @@ -0,0 +1,140 @@ +--- +date: 2026-03-04 +authors: + - nmulepati +--- + +# Model Facade Overhaul PR-2 Architecture Notes + +This document captures the architecture intent for PR-2 from +`plans/343/model-facade-overhaul-plan-step-1.md`. + +## Goal + +Switch `ModelFacade` from direct LiteLLM router usage to the `ModelClient` protocol +introduced in PR-1. After this PR, `ModelFacade` consumes only canonical types +(`ChatCompletionResponse`, `EmbeddingResponse`, `ImageGenerationResponse`) and has +no direct import or runtime dependency on LiteLLM response shapes. + +## What Changes + +### 1. ModelFacade internals rewired to ModelClient + +`ModelFacade.__init__` currently constructs a `CustomRouter` and calls it directly: + +```python +self._router = CustomRouter([self._litellm_deployment], ...) +# ... +response = self._router.completion(model=..., messages=..., **kwargs) +``` + +After PR-2, it receives a `ModelClient` (selected by factory) and builds canonical requests: + +```python +self._client: ModelClient # injected via factory +# ... +request = ChatCompletionRequest(model=..., messages=..., **consolidated) +response: ChatCompletionResponse = self._client.completion(request) +``` + +The same pattern applies to embeddings (`EmbeddingRequest` → `EmbeddingResponse`) and +image generation (`ImageGenerationRequest` → `ImageGenerationResponse`). + +### 2. Client factory + +New file: `clients/factory.py` + +Responsible for selecting the right `ModelClient` adapter given a `ModelConfig` and +provider context. For PR-2, the only adapter is `LiteLLMBridgeClient`. The factory +encapsulates router construction and deployment config that currently lives in +`ModelFacade._get_litellm_deployment`. + +`models/factory.py` (`create_model_registry`) is updated to use the client factory +when constructing each `ModelFacade`. + +### 3. 
MCP compatibility update
+
+`MCPFacade` methods (`has_tool_calls`, `tool_call_count`, `process_completion_response`,
+`refuse_completion_response`) currently accept `Any` and traverse
+`completion_response.choices[0].message` with `getattr` for LiteLLM shapes.
+
+PR-2 updates these to accept `ChatCompletionResponse` and read from canonical fields:
+
+- `response.message.tool_calls` → `list[ToolCall]` (id, name, arguments_json)
+- `response.message.content` → `str | None`
+- `response.message.reasoning_content` → `str | None`
+
+`_extract_tool_calls` and `_normalize_tool_call` simplify significantly because
+canonical `ToolCall` is already normalized (no nested `function` key, no dict vs
+object polymorphism).
+
+### 4. Usage tracking consolidation
+
+The three existing methods:
+
+- `_track_token_usage_from_completion`
+- `_track_token_usage_from_embedding`
+- `_track_token_usage_from_image_diffusion`
+
+All read from provider-specific usage shapes (`litellm.types.utils.*`). PR-2 replaces
+them with a single helper that reads from canonical `Usage`:
+
+```python
+def _track_usage(self, usage: Usage | None, *, is_request_successful: bool) -> None
+```
+
+### 5. Image extraction moves into adapter
+
+`ModelFacade` currently does image extraction from raw LiteLLM responses
+(`_try_extract_base64`, `_generate_image_chat_completion`, `_generate_image_diffusion`).
+
+After PR-2, the adapter returns `ImageGenerationResponse.images: list[ImagePayload]`
+with `b64_data` already resolved. `ModelFacade.generate_image` / `agenerate_image`
+simply read `response.images` and extract `b64_data` values — no more format
+detection, URL downloading, or data URI parsing at the facade level.
+
+### 6. LiteLLM type removal from facade
+
+After PR-2, `facade.py` no longer imports:
+
+- `litellm` (the module, currently used for type hints)
+- `CustomRouter`, `LiteLLMRouterDefaultKwargs`
+- `litellm.types.utils.ModelResponse`, `EmbeddingResponse`, `ImageResponse`, `ImageUsage`
+
+These remain internal to `LiteLLMBridgeClient` and `models/factory.py`.
+
+### 7. Adapter lifecycle wiring
+
+`ModelClient.close()` / `aclose()` are wired through `ModelRegistry` so adapter
+resources (HTTP clients, connection pools) are torn down deterministically when
+generation is complete.
+
+- `ModelRegistry` gains `close()` / `aclose()` that iterate owned facades.
+- `ModelFacade` gains `close()` / `aclose()` that delegate to `self._client`.
+- `ResourceProvider` (or equivalent teardown hook) calls `ModelRegistry.close()`.
+
+A non-normative sketch of this wiring appears below, after the invariants list.
+
+## What Does NOT Change
+
+1. `ModelFacade` public method signatures — callers see the same API.
+2. MCP tool-loop behavior — tool turns, refusal, parallel execution all preserved.
+3. Usage accounting semantics — token, request, image, and tool usage remain identical.
+4. Error boundaries — `@catch_llm_exceptions` / `@acatch_llm_exceptions` decorators
+   and `DataDesignerError` subclass hierarchy remain stable.
+5. `consolidate_kwargs` merge semantics for `extra_body` / `extra_headers`.
+6. `generate` / `agenerate` parser correction/restart loop logic.
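+
+## Lifecycle Wiring Sketch (Non-Normative)
+
+A minimal sketch of the section 7 wiring, under stated assumptions: the registry's
+internal container name (`_facades`) is illustrative, and only the delegation shape
+is intended to be normative. `ModelClient.close()` / `aclose()` are the protocol
+methods already implemented by `LiteLLMBridgeClient`.
+
+```python
+class ModelFacade:
+    def close(self) -> None:
+        # The adapter owns transport resources (HTTP clients, pools); delegate.
+        self._client.close()
+
+    async def aclose(self) -> None:
+        await self._client.aclose()
+
+
+class ModelRegistry:
+    def close(self) -> None:
+        # Deterministic teardown, invoked from the ResourceProvider hook.
+        for facade in self._facades.values():  # container name assumed
+            facade.close()
+
+    async def aclose(self) -> None:
+        for facade in self._facades.values():
+            await facade.aclose()
+```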
+ +## Files Touched + +| File | Change | +|---|---| +| `models/facade.py` | Rewire to `ModelClient`, canonical types, consolidated usage tracking | +| `models/factory.py` | Use client factory to inject `ModelClient` into `ModelFacade` | +| `models/registry.py` | Add `close` / `aclose` lifecycle methods | +| `clients/factory.py` | New — adapter selection by provider config | +| `mcp/facade.py` | Accept `ChatCompletionResponse` instead of raw LiteLLM response | + +## Planned Follow-On + +PR-3 introduces the OpenAI-compatible native adapter with shared retry/throttle +infrastructure. At that point, the client factory gains a second adapter option +alongside the LiteLLM bridge. From 89a6d4e90725bbbfce1d75da0c1ce585f80216eb Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 5 Mar 2026 09:53:36 -0700 Subject: [PATCH 19/27] save progress on pr2 --- .gitignore | 3 + .../src/data_designer/engine/mcp/facade.py | 272 ++----- .../engine/models/clients/__init__.py | 4 +- .../engine/models/clients/factory.py | 50 ++ .../src/data_designer/engine/models/errors.py | 109 ++- .../src/data_designer/engine/models/facade.py | 521 ++++--------- .../data_designer/engine/models/factory.py | 10 +- .../data_designer/engine/models/registry.py | 22 +- .../data_designer/engine/testing/__init__.py | 2 + .../data_designer/engine/testing/fixtures.py | 70 +- .../src/data_designer/engine/testing/stubs.py | 95 ++- .../tests/engine/mcp/test_mcp_facade.py | 232 ++---- .../tests/engine/models/conftest.py | 14 +- .../tests/engine/models/test_facade.py | 729 ++++++------------ 14 files changed, 798 insertions(+), 1335 deletions(-) create mode 100644 packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py diff --git a/.gitignore b/.gitignore index 99f6e26ce..02a115b99 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,6 @@ packages/data-designer/README.md .cursor/rules/cerebro.mdc .cursor/mcp.json .claude/rules/cerebro.md + +# Claude worktrees +.claude/worktrees/ diff --git a/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py b/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py index f49069e18..89289ad5b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py @@ -4,7 +4,6 @@ from __future__ import annotations import json -import uuid from typing import Any from data_designer.config.mcp import MCPProviderT, ToolConfig @@ -12,6 +11,7 @@ from data_designer.engine.mcp.errors import DuplicateToolNameError, MCPConfigurationError, MCPToolError from data_designer.engine.mcp.registry import MCPToolDefinition from data_designer.engine.model_provider import MCPProviderRegistry +from data_designer.engine.models.clients.types import ChatCompletionResponse, ToolCall from data_designer.engine.models.utils import ChatMessage from data_designer.engine.secret_resolver import SecretResolver @@ -38,13 +38,6 @@ def __init__( secret_resolver: SecretResolver, mcp_provider_registry: MCPProviderRegistry, ) -> None: - """Initialize the MCPFacade. - - Args: - tool_config: The tool configuration this facade is scoped to. - secret_resolver: Resolver for secrets referenced in provider configs. - mcp_provider_registry: Registry of MCP provider configurations. 
- """ self._tool_config = tool_config self._secret_resolver = secret_resolver self._mcp_provider_registry = mcp_provider_registry @@ -79,39 +72,17 @@ def timeout_sec(self) -> float | None: return self._tool_config.timeout_sec @staticmethod - def tool_call_count(completion_response: Any) -> int: - """Count the number of tool calls in a completion response. - - Args: - completion_response: The LLM completion response (litellm.ModelResponse). - - Returns: - Number of tool calls in the response (0 if none). - """ - message = completion_response.choices[0].message - tool_calls = getattr(message, "tool_calls", None) - if tool_calls is None: - return 0 - return len(tool_calls) + def tool_call_count(completion_response: ChatCompletionResponse) -> int: + """Count the number of tool calls in a completion response.""" + return len(completion_response.message.tool_calls) @staticmethod - def has_tool_calls(completion_response: Any) -> bool: + def has_tool_calls(completion_response: ChatCompletionResponse) -> bool: """Returns True if tool calls are present in the completion response.""" - return MCPFacade.tool_call_count(completion_response) > 0 + return len(completion_response.message.tool_calls) > 0 def _resolve_provider(self, provider: MCPProviderT) -> MCPProviderT: - """Resolve secret references in an MCP provider's api_key. - - Creates a copy of the provider with the api_key resolved from any secret - reference (e.g., "env:API_KEY") to its actual value. - - Args: - provider: The MCP provider config. - - Returns: - A copy of the provider with resolved api_key, or the original provider - if no api_key is configured. - """ + """Resolve secret references in an MCP provider's api_key.""" api_key_ref = getattr(provider, "api_key", None) if not api_key_ref: return provider @@ -168,7 +139,7 @@ def get_tool_schemas(self) -> list[dict[str, Any]]: def process_completion_response( self, - completion_response: Any, + completion_response: ChatCompletionResponse, ) -> list[ChatMessage]: """Process an LLM completion response and execute any tool calls. @@ -178,10 +149,7 @@ def process_completion_response( tool calls), and returns the messages for continuing the conversation. Args: - completion_response: The completion response object from the LLM, - typically from `router.completion()`. Expected to have a - `choices[0].message` structure with optional `content`, - `reasoning_content`, and `tool_calls` attributes. + completion_response: The canonical ChatCompletionResponse from the model client. Returns: A list of ChatMessages to append to the conversation history: @@ -189,29 +157,23 @@ def process_completion_response( - If no tool calls: [assistant_message] Raises: - MCPToolError: If a tool call is missing a name. - MCPToolError: If tool call arguments cannot be parsed as JSON. - MCPToolError: If tool call arguments are an unsupported type. MCPToolError: If a requested tool is not in the allowed tools list. MCPToolError: If tool execution fails or times out. MCPConfigurationError: If a requested tool is not found on any configured provider. 
""" - message = completion_response.choices[0].message + message = completion_response.message - # Extract response content and reasoning content response_content = message.content or "" - reasoning_content = getattr(message, "reasoning_content", None) + reasoning_content = message.reasoning_content # Strip whitespace if reasoning is present (models often add extra newlines) if reasoning_content: response_content = response_content.strip() reasoning_content = reasoning_content.strip() - # Extract and normalize tool calls - tool_calls = self._extract_tool_calls(message) + tool_calls = message.tool_calls if not tool_calls: - # No tool calls - just return the assistant message return [ ChatMessage.as_assistant( content=response_content, @@ -220,49 +182,43 @@ def process_completion_response( ] # Has tool calls - execute and return all messages - assistant_message = self._build_assistant_tool_message(response_content, tool_calls, reasoning_content) - tool_messages = self._execute_tool_calls_internal(tool_calls) + tool_call_dicts = _canonical_tool_calls_to_dicts(tool_calls) + assistant_message = self._build_assistant_tool_message(response_content, tool_call_dicts, reasoning_content) + tool_messages = self._execute_tool_calls_from_canonical(tool_calls) return [assistant_message, *tool_messages] def refuse_completion_response( self, - completion_response: Any, + completion_response: ChatCompletionResponse, refusal_message: str | None = None, ) -> list[ChatMessage]: """Refuse tool calls without executing them. Used when the tool call turn budget is exhausted. Returns messages that include the assistant's tool call request but with refusal - responses instead of actual tool results. This allows the model - to gracefully degrade and provide a final response without tools. + responses instead of actual tool results. Args: - completion_response: The LLM completion response containing tool calls. - refusal_message: Optional custom refusal message. Defaults to a - standard message about tool budget exhaustion. + completion_response: The canonical ChatCompletionResponse containing tool calls. + refusal_message: Optional custom refusal message. Returns: - A list of ChatMessages to append to the conversation history: - - If tool calls were present: [assistant_message_with_tool_calls, *refusal_messages] - - If no tool calls: [assistant_message] + A list of ChatMessages to append to the conversation history. 
""" - message = completion_response.choices[0].message + message = completion_response.message - # Extract response content and reasoning content response_content = message.content or "" - reasoning_content = getattr(message, "reasoning_content", None) + reasoning_content = message.reasoning_content # Strip whitespace if reasoning is present (models often add extra newlines) if reasoning_content: response_content = response_content.strip() reasoning_content = reasoning_content.strip() - # Extract and normalize tool calls - tool_calls = self._extract_tool_calls(message) + tool_calls = message.tool_calls if not tool_calls: - # No tool calls to refuse - just return assistant message return [ ChatMessage.as_assistant( content=response_content, @@ -271,137 +227,22 @@ def refuse_completion_response( ] # Build assistant message with tool calls (same as normal) - assistant_message = self._build_assistant_tool_message(response_content, tool_calls, reasoning_content) + tool_call_dicts = _canonical_tool_calls_to_dicts(tool_calls) + assistant_message = self._build_assistant_tool_message(response_content, tool_call_dicts, reasoning_content) # Build refusal messages instead of executing tools refusal = refusal_message or DEFAULT_TOOL_REFUSAL_MESSAGE - tool_messages = [ChatMessage.as_tool(content=refusal, tool_call_id=tc["id"]) for tc in tool_calls] + tool_messages = [ChatMessage.as_tool(content=refusal, tool_call_id=tc.id) for tc in tool_calls] return [assistant_message, *tool_messages] - def _extract_tool_calls(self, message: Any) -> list[dict[str, Any]]: - """Extract and normalize tool calls from an LLM response message. - - Handles various LLM response formats (dict or object with attributes) - and normalizes them into a consistent dictionary format. Supports - parallel tool calling where the model returns multiple tool calls - in a single response. - - Args: - message: The LLM response message, either as a dictionary or an object - with a 'tool_calls' attribute. - - Returns: - A list of normalized tool call dictionaries. Each dictionary contains: - - 'id': Unique identifier for the tool call (generated if not provided) - - 'name': The name of the tool to call - - 'arguments': Parsed arguments as a dictionary - - 'arguments_json': Arguments serialized as a JSON string - - Returns an empty list if no tool calls are present in the message. - - Raises: - MCPToolError: If a tool call is missing a name. - MCPToolError: If tool call arguments cannot be parsed as JSON. - MCPToolError: If tool call arguments are an unsupported type. - """ - raw_tool_calls = getattr(message, "tool_calls", None) - if raw_tool_calls is None and isinstance(message, dict): - raw_tool_calls = message.get("tool_calls") - if not raw_tool_calls: - return [] - - tool_calls: list[dict[str, Any]] = [] - for raw_tool_call in raw_tool_calls: - tool_calls.append(self._normalize_tool_call(raw_tool_call)) - return tool_calls - - def _normalize_tool_call(self, raw_tool_call: Any) -> dict[str, Any]: - """Normalize a tool call from various LLM response formats. - - Handles both dictionary and object representations of tool calls, - supporting the OpenAI format (with nested 'function' key) and - flattened formats. - - Args: - raw_tool_call: A tool call in any supported format. 
- - Returns: - A normalized tool call dictionary with keys: - - 'id': Tool call identifier (UUID generated if not provided) - - 'name': The tool name - - 'arguments': Parsed arguments dictionary - - 'arguments_json': JSON string of arguments - - Raises: - MCPToolError: If the tool call is missing a name or has invalid - arguments that cannot be parsed as JSON. - """ - if isinstance(raw_tool_call, dict): - tool_call_id = raw_tool_call.get("id") - function = raw_tool_call.get("function") or {} - name = function.get("name") or raw_tool_call.get("name") - arguments = function.get("arguments") or raw_tool_call.get("arguments") - else: - tool_call_id = getattr(raw_tool_call, "id", None) - function = getattr(raw_tool_call, "function", None) - name = getattr(function, "name", None) if function is not None else getattr(raw_tool_call, "name", None) - arguments = ( - getattr(function, "arguments", None) - if function is not None - else getattr(raw_tool_call, "arguments", None) - ) - - if not name: - raise MCPToolError("MCP tool call is missing a tool name.") - - arguments_payload: dict[str, Any] - if arguments is None or arguments == "": - arguments_payload = {} - elif isinstance(arguments, str): - try: - arguments_payload = json.loads(arguments) - except json.JSONDecodeError as exc: - raise MCPToolError(f"Invalid tool arguments for '{name}': {arguments}") from exc - elif isinstance(arguments, dict): - arguments_payload = arguments - else: - raise MCPToolError(f"Unsupported tool arguments type for '{name}': {type(arguments)!r}") - - # Normalize arguments_json to ensure valid, canonical JSON - try: - arguments_json = json.dumps(arguments_payload) - except TypeError as exc: - raise MCPToolError(f"Non-serializable tool arguments for '{name}': {exc}") from exc - - return { - "id": tool_call_id or uuid.uuid4().hex, - "name": name, - "arguments": arguments_payload, - "arguments_json": arguments_json, - } - def _build_assistant_tool_message( self, response: str | None, tool_calls: list[dict[str, Any]], reasoning_content: str | None = None, ) -> ChatMessage: - """Build the assistant message containing tool call requests. - - Constructs a message in the format expected by the LLM conversation - history, representing the assistant's request to call tools. - - Args: - response: The assistant's text response content. May be empty if - the assistant only requested tool calls without additional text. - tool_calls: List of normalized tool call dictionaries. - reasoning_content: Optional reasoning content from the assistant's - response. If provided, will be included under the 'reasoning_content' key. - - Returns: - A ChatMessage representing the assistant message with tool call requests. - """ + """Build the assistant message containing tool call requests.""" tool_calls_payload = [ { "id": tool_call["id"], @@ -416,38 +257,22 @@ def _build_assistant_tool_message( tool_calls=tool_calls_payload, ) - def _execute_tool_calls_internal( + def _execute_tool_calls_from_canonical( self, - tool_calls: list[dict[str, Any]], + tool_calls: list[ToolCall], ) -> list[ChatMessage]: - """Execute tool calls in parallel and return tool response messages. - - Validates all tool calls, then executes them concurrently via the io module - using call_tools_parallel. This leverages parallel tool calling when the - model returns multiple tool calls in a single response. - - Args: - tool_calls: List of normalized tool call dictionaries to execute. - - Returns: - A list of tool response messages, one per tool call. 
- - Raises: - MCPToolError: If a tool is not in the allowed tools list or if - the MCP provider returns an error. - """ + """Execute canonical ToolCall objects and return tool response messages.""" allowed_tools = set(self._tool_config.allow_tools) if self._tool_config.allow_tools else None - # Validate all tool calls and collect provider + args calls_to_execute: list[tuple[MCPProviderT, str, dict[str, Any], str]] = [] - for tool_call in tool_calls: - tool_name = tool_call["name"] - if allowed_tools is not None and tool_name not in allowed_tools: + for tc in tool_calls: + if allowed_tools is not None and tc.name not in allowed_tools: providers_str = ", ".join(repr(p) for p in self._tool_config.providers) - raise MCPToolError(f"Tool {tool_name!r} is not permitted for providers: {providers_str}.") + raise MCPToolError(f"Tool {tc.name!r} is not permitted for providers: {providers_str}.") - resolved_provider = self._find_resolved_provider_for_tool(tool_name) - calls_to_execute.append((resolved_provider, tool_name, tool_call["arguments"], tool_call["id"])) + arguments = json.loads(tc.arguments_json) if tc.arguments_json else {} + resolved_provider = self._find_resolved_provider_for_tool(tc.name) + calls_to_execute.append((resolved_provider, tc.name, arguments, tc.id)) # Execute all calls in parallel results = mcp_io.call_tools( @@ -455,24 +280,13 @@ def _execute_tool_calls_internal( timeout_sec=self._tool_config.timeout_sec, ) - # Build response messages return [ ChatMessage.as_tool(content=result.content, tool_call_id=call[3]) for result, call in zip(results, calls_to_execute) ] def _find_resolved_provider_for_tool(self, tool_name: str) -> MCPProviderT: - """Find the provider that has the given tool and return it with resolved api_key. - - Args: - tool_name: The name of the tool to find. - - Returns: - The provider object (with resolved api_key) that has the tool. - - Raises: - MCPConfigurationError: If no provider has the tool. 
- """ + """Find the provider that has the given tool and return it with resolved api_key.""" for provider_name in self._tool_config.providers: provider = self._mcp_provider_registry.get_provider(provider_name) resolved_provider = self._resolve_provider(provider) @@ -483,3 +297,15 @@ def _find_resolved_provider_for_tool(self, tool_name: str) -> MCPProviderT: return resolved_provider raise MCPConfigurationError(f"Tool {tool_name!r} not found on any configured provider.") + + +def _canonical_tool_calls_to_dicts(tool_calls: list[ToolCall]) -> list[dict[str, Any]]: + """Convert canonical ToolCall objects to the internal dict format for ChatMessage.""" + return [ + { + "id": tc.id, + "name": tc.name, + "arguments_json": tc.arguments_json, + } + for tc in tool_calls + ] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index dec52401a..99312e4b8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -8,6 +8,7 @@ map_http_error_to_provider_error, map_http_status_to_provider_error_kind, ) +from data_designer.engine.models.clients.factory import create_model_client from data_designer.engine.models.clients.types import ( AssistantMessage, ChatCompletionRequest, @@ -23,12 +24,12 @@ ) __all__ = [ - "HttpResponse", "AssistantMessage", "ChatCompletionRequest", "ChatCompletionResponse", "EmbeddingRequest", "EmbeddingResponse", + "HttpResponse", "ImageGenerationRequest", "ImageGenerationResponse", "ImagePayload", @@ -37,6 +38,7 @@ "ProviderErrorKind", "ToolCall", "Usage", + "create_model_client", "map_http_error_to_provider_error", "map_http_status_to_provider_error_kind", ] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py new file mode 100644 index 000000000..c7e32ebcd --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py @@ -0,0 +1,50 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import data_designer.lazy_heavy_imports as lazy +from data_designer.config.models import ModelConfig +from data_designer.engine.model_provider import ModelProviderRegistry +from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient +from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs +from data_designer.engine.secret_resolver import SecretResolver + + +def create_model_client( + model_config: ModelConfig, + secret_resolver: SecretResolver, + model_provider_registry: ModelProviderRegistry, +) -> ModelClient: + """Create a ModelClient for the given model configuration. + + Resolves the provider, API key, and constructs a LiteLLM router wrapped in + a LiteLLMBridgeClient adapter. + + Args: + model_config: The model configuration to create a client for. + secret_resolver: Resolver for secrets referenced in provider configs. + model_provider_registry: Registry of model provider configurations. + + Returns: + A ModelClient instance ready for use. 
+ """ + provider = model_provider_registry.get_provider(model_config.provider) + api_key = None + if provider.api_key: + api_key = secret_resolver.resolve(provider.api_key) + api_key = api_key or "not-used-but-required" + + litellm_params = lazy.litellm.LiteLLM_Params( + model=f"{provider.provider_type}/{model_config.model}", + api_base=provider.endpoint, + api_key=api_key, + max_parallel_requests=model_config.inference_parameters.max_parallel_requests, + ) + deployment = { + "model_name": model_config.model, + "litellm_params": litellm_params.model_dump(), + } + router = CustomRouter([deployment], **LiteLLMRouterDefaultKwargs().model_dump()) + return LiteLLMBridgeClient(provider_name=provider.name, router=router) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index e29c81325..7a98f40d7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -12,6 +12,7 @@ import data_designer.lazy_heavy_imports as lazy from data_designer.engine.errors import DataDesignerError +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind if TYPE_CHECKING: import litellm @@ -34,8 +35,7 @@ def get_exception_primary_cause(exception: BaseException) -> BaseException: """ if exception.__cause__ is None: return exception - else: - return get_exception_primary_cause(exception.__cause__) + return get_exception_primary_cause(exception.__cause__) class GenerationValidationFailureError(Exception): ... @@ -124,7 +124,13 @@ def handle_llm_exceptions( ) err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose) match exception: - # Common errors that can come from LiteLLM + # Canonical ProviderError from the client adapter layer + case ProviderError(kind=kind): + _raise_from_provider_error( + exception, kind, model_name, model_provider_name, purpose, authentication_error, err_msg_parser + ) + + # LiteLLM-specific errors (safety net during bridge period) case lazy.litellm.exceptions.APIError(): raise err_msg_parser.parse_api_error(exception, authentication_error) from None @@ -228,7 +234,7 @@ def catch_llm_exceptions(func: Callable) -> Callable: """ @wraps(func) - def wrapper(model_facade: Any, *args, **kwargs): + def wrapper(model_facade: Any, *args: Any, **kwargs: Any) -> Any: try: return func(model_facade, *args, **kwargs) except Exception as e: @@ -315,7 +321,7 @@ def parse_context_window_exceeded_error( ) def parse_api_error( - self, exception: litellm.exceptions.InternalServerError, auth_error_msg: FormattedLLMErrorMessage + self, exception: litellm.exceptions.APIError, auth_error_msg: FormattedLLMErrorMessage ) -> DataDesignerError: if "Error code: 403" in str(exception): return ModelAuthenticationError(auth_error_msg) @@ -326,3 +332,96 @@ def parse_api_error( solution=f"Try again in a few moments. 
Check with your model provider {self.model_provider_name!r} if the issue persists.", ) ) + + +def _raise_from_provider_error( + exception: ProviderError, + kind: ProviderErrorKind, + model_name: str, + model_provider_name: str, + purpose: str, + authentication_error: FormattedLLMErrorMessage, + err_msg_parser: DownstreamLLMExceptionMessageParser, +) -> None: + """Map a canonical ProviderError to the appropriate DataDesignerError subclass.""" + _KIND_MAP: dict[ProviderErrorKind, type[DataDesignerError]] = { + ProviderErrorKind.RATE_LIMIT: ModelRateLimitError, + ProviderErrorKind.TIMEOUT: ModelTimeoutError, + ProviderErrorKind.NOT_FOUND: ModelNotFoundError, + ProviderErrorKind.PERMISSION_DENIED: ModelPermissionDeniedError, + ProviderErrorKind.UNSUPPORTED_PARAMS: ModelUnsupportedParamsError, + ProviderErrorKind.INTERNAL_SERVER: ModelInternalServerError, + ProviderErrorKind.UNPROCESSABLE_ENTITY: ModelUnprocessableEntityError, + ProviderErrorKind.API_CONNECTION: ModelAPIConnectionError, + } + + _MESSAGES: dict[ProviderErrorKind, tuple[str, str]] = { + ProviderErrorKind.RATE_LIMIT: ( + f"You have exceeded the rate limit for model {model_name!r} while {purpose}.", + "Wait and try again in a few moments.", + ), + ProviderErrorKind.TIMEOUT: ( + f"The request to model {model_name!r} timed out while {purpose}.", + "Check your connection and try again. You may need to increase the timeout setting for the model.", + ), + ProviderErrorKind.NOT_FOUND: ( + f"The specified model {model_name!r} could not be found while {purpose}.", + f"Check that the model name is correct and supported by your model provider {model_provider_name!r} and try again.", + ), + ProviderErrorKind.PERMISSION_DENIED: ( + f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.", + f"Use an API key that has the right permissions for the model or use a model the API key in use has access to in model provider {model_provider_name!r}.", + ), + ProviderErrorKind.UNSUPPORTED_PARAMS: ( + f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.", + f"Review the documentation for model provider {model_provider_name!r} and adjust your request.", + ), + ProviderErrorKind.INTERNAL_SERVER: ( + f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.", + f"Try again in a few moments. Check with your model provider {model_provider_name!r} if the issue persists.", + ), + ProviderErrorKind.UNPROCESSABLE_ENTITY: ( + f"The request to model {model_name!r} failed despite correct request format while {purpose}.", + "This is most likely temporary. Try again in a few moments.", + ), + ProviderErrorKind.API_CONNECTION: ( + f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.", + "Check your network/proxy/firewall settings.", + ), + } + + if kind == ProviderErrorKind.AUTHENTICATION: + raise ModelAuthenticationError(authentication_error) from None + + if kind == ProviderErrorKind.CONTEXT_WINDOW_EXCEEDED: + raise ModelContextWindowExceededError( + FormattedLLMErrorMessage( + cause=f"The input data for model '{model_name}' was found to exceed its supported context width while {purpose}.", + solution="Check the model's supported max context width. 
Adjust the length of your input along with completions and try again.", + ) + ) from None + + if kind == ProviderErrorKind.BAD_REQUEST: + err_msg = FormattedLLMErrorMessage( + cause=f"The request for model {model_name!r} was found to be malformed or missing required parameters while {purpose}.", + solution="Check your request parameters and try again.", + ) + if "is not a multimodal model" in str(exception): + err_msg = FormattedLLMErrorMessage( + cause=f"Model {model_name!r} is not a multimodal model, but it looks like you are trying to provide multimodal context while {purpose}.", + solution="Check your request parameters and try again.", + ) + raise ModelBadRequestError(err_msg) from None + + if kind in _KIND_MAP and kind in _MESSAGES: + error_cls = _KIND_MAP[kind] + cause_str, solution_str = _MESSAGES[kind] + raise error_cls(FormattedLLMErrorMessage(cause=cause_str, solution=solution_str)) from None + + # Fallback for API_ERROR and UNSUPPORTED_CAPABILITY + raise ModelAPIError( + FormattedLLMErrorMessage( + cause=f"An unexpected API error occurred with model {model_name!r} while {purpose}.", + solution=f"Try again in a few moments. Check with your model provider {model_provider_name!r} if the issue persists.", + ) + ) from None diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 13a0c1634..91545ecb2 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -9,16 +9,18 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any -import data_designer.lazy_heavy_imports as lazy from data_designer.config.models import GenerationType, ModelConfig, ModelProvider -from data_designer.config.utils.image_helpers import ( - extract_base64_from_data_uri, - is_base64_image, - is_image_diffusion_model, - load_image_url_to_base64, -) +from data_designer.config.utils.image_helpers import is_image_diffusion_model from data_designer.engine.mcp.errors import MCPConfigurationError from data_designer.engine.model_provider import ModelProviderRegistry +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + Usage, +) from data_designer.engine.models.errors import ( GenerationValidationFailureError, ImageGenerationError, @@ -26,17 +28,14 @@ catch_llm_exceptions, get_exception_primary_cause, ) -from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs from data_designer.engine.models.parsers.errors import ParserException from data_designer.engine.models.usage import ImageUsageStats, ModelUsageStats, RequestUsageStats, TokenUsageStats from data_designer.engine.models.utils import ChatMessage, prompt_to_messages -from data_designer.engine.secret_resolver import SecretResolver if TYPE_CHECKING: - import litellm - from data_designer.engine.mcp.facade import MCPFacade from data_designer.engine.mcp.registry import MCPRegistry + from data_designer.engine.models.clients.base import ModelClient def _identity(x: Any) -> Any: @@ -44,50 +43,28 @@ def _identity(x: Any) -> Any: return x -def _try_extract_base64(source: str | litellm.types.utils.ImageObject) -> str | None: - """Try to extract base64 image data from a data URI string or image response object. - - Args: - source: Either a data URI string (e.g. 
"data:image/png;base64,...") - or a litellm ImageObject with b64_json/url attributes. - - Returns: - Base64-encoded image string, or None if extraction fails. - """ - try: - if isinstance(source, str): - return extract_base64_from_data_uri(source) - - if getattr(source, "b64_json", None): - return source.b64_json - - if getattr(source, "url", None): - return load_image_url_to_base64(source.url) - except Exception: - logger.debug(f"Failed to extract base64 from source of type {type(source).__name__}") - return None - - return None +logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) +# Known keyword arguments extracted into ChatCompletionRequest fields. +_COMPLETION_REQUEST_FIELDS = frozenset( + {"temperature", "top_p", "max_tokens", "timeout", "tools", "extra_body", "extra_headers"} +) class ModelFacade: def __init__( self, model_config: ModelConfig, - secret_resolver: SecretResolver, model_provider_registry: ModelProviderRegistry, *, + client: ModelClient, mcp_registry: MCPRegistry | None = None, ) -> None: self._model_config = model_config - self._secret_resolver = secret_resolver self._model_provider_registry = model_provider_registry + self._client = client self._mcp_registry = mcp_registry - self._litellm_deployment = self._get_litellm_deployment(model_config) - self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump()) self._usage_stats = ModelUsageStats() @property @@ -119,8 +96,8 @@ def usage_stats(self) -> ModelUsageStats: return self._usage_stats def completion( - self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs - ) -> litellm.ModelResponse: + self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + ) -> ChatCompletionResponse: message_payloads = [message.to_dict() for message in messages] logger.debug( f"Prompting model {self.model_name!r}...", @@ -129,24 +106,23 @@ def completion( response = None kwargs = self.consolidate_kwargs(**kwargs) try: - response = self._router.completion(model=self.model_name, messages=message_payloads, **kwargs) + request = self._build_chat_completion_request(message_payloads, kwargs) + response = self._client.completion(request) logger.debug( f"Received completion from model {self.model_name!r}", extra={ "model": self.model_name, "response": response, - "text": response.choices[0].message.content, + "text": response.message.content, "usage": self._usage_stats.model_dump(), }, ) return response - except Exception as e: - raise e finally: if not skip_usage_tracking and response is not None: - self._track_token_usage_from_completion(response) + self._track_usage(response.usage, is_request_successful=True) - def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: + def consolidate_kwargs(self, **kwargs: Any) -> dict[str, Any]: # Remove purpose from kwargs to avoid passing it to the model kwargs.pop("purpose", None) kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} @@ -169,7 +145,7 @@ def generate( max_conversation_restarts: int = 0, skip_usage_tracking: bool = False, purpose: str | None = None, - **kwargs, + **kwargs: Any, ) -> tuple[Any, list[ChatMessage]]: """Generate a parsed output with correction steps. 
@@ -266,8 +242,8 @@ def generate( continue # Back to top # No tool calls remaining to process - response = (completion_response.choices[0].message.content or "").strip() - reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None) + response = (completion_response.message.content or "").strip() + reasoning_trace = completion_response.message.reasoning_content messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) curr_num_correction_steps += 1 @@ -306,7 +282,7 @@ def generate( @catch_llm_exceptions def generate_text_embeddings( - self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any ) -> list[list[float]]: logger.debug( f"Generating embeddings with model {self.model_name!r}...", @@ -316,26 +292,24 @@ def generate_text_embeddings( }, ) kwargs = self.consolidate_kwargs(**kwargs) - response = None + response: EmbeddingResponse | None = None try: - response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs) + request = self._build_embedding_request(input_texts, kwargs) + response = self._client.embeddings(request) logger.debug( f"Received embeddings from model {self.model_name!r}", extra={ "model": self.model_name, - "embedding_count": len(response.data) if response.data else 0, + "embedding_count": len(response.vectors), "usage": self._usage_stats.model_dump(), }, ) - if response.data and len(response.data) == len(input_texts): - return [data["embedding"] for data in response.data] - else: - raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") - except Exception as e: - raise e + if len(response.vectors) == len(input_texts): + return response.vectors + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.vectors)}") finally: if not skip_usage_tracking and response is not None: - self._track_token_usage_from_embedding(response) + self._track_usage(response.usage, is_request_successful=True) @catch_llm_exceptions def generate_image( @@ -343,13 +317,13 @@ def generate_image( prompt: str, multi_modal_context: list[dict[str, Any]] | None = None, skip_usage_tracking: bool = False, - **kwargs, + **kwargs: Any, ) -> list[str]: """Generate image(s) and return base64-encoded data. Automatically detects the appropriate API based on model name: - - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) → image_generation API - - All other models → chat/completions API (default) + - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) -> image_generation API + - All other models -> chat/completions API (default) Both paths return base64-encoded image data. If the API returns multiple images, all are returned in the list. 
@@ -372,15 +346,19 @@ def generate_image( extra={"model": self.model_name, "prompt": prompt}, ) - # Auto-detect API type based on model name - if is_image_diffusion_model(self.model_name): - images = self._generate_image_diffusion(prompt, skip_usage_tracking, **kwargs) - else: - images = self._generate_image_chat_completion(prompt, multi_modal_context, skip_usage_tracking, **kwargs) + kwargs = self.consolidate_kwargs(**kwargs) + request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) + response = self._client.generate_image(request) + + images = [img.b64_data for img in response.images] + + if not images: + raise ImageGenerationError("No image data found in image generation response") # Track image usage - if not skip_usage_tracking and len(images) > 0: + if not skip_usage_tracking and images: self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) + self._track_usage(response.usage, is_request_successful=True) return images @@ -395,186 +373,87 @@ def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: except ValueError as exc: raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc - def _generate_image_chat_completion( - self, - prompt: str, - multi_modal_context: list[dict[str, Any]] | None = None, - skip_usage_tracking: bool = False, - **kwargs, - ) -> list[str]: - """Generate image(s) using autoregressive model via chat completions API. + def _build_chat_completion_request( + self, messages: list[dict[str, Any]], kwargs: dict[str, Any] + ) -> ChatCompletionRequest: + """Build a ChatCompletionRequest from message payloads and consolidated kwargs.""" + request_fields: dict[str, Any] = {"model": self.model_name, "messages": messages} + metadata: dict[str, Any] = {} - Args: - prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation - skip_usage_tracking: Whether to skip usage tracking - **kwargs: Additional arguments to pass to the model - - Returns: - List of base64-encoded image strings - """ - messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) - - response = None - try: - response = self.completion( - messages=messages, - skip_usage_tracking=skip_usage_tracking, - **kwargs, - ) - - logger.debug( - f"Received image(s) from autoregressive model {self.model_name!r}", - extra={"model": self.model_name, "response": response}, - ) - - # Validate response structure - if not response.choices or len(response.choices) == 0: - raise ImageGenerationError("Image generation response missing choices") - - message = response.choices[0].message - images = [] - - # Extract base64 from images attribute (primary path) - if hasattr(message, "images") and message.images: - for image in message.images: - # Handle different response formats - if isinstance(image, dict) and "image_url" in image: - image_url = image["image_url"] - - if isinstance(image_url, dict) and "url" in image_url: - if (b64 := _try_extract_base64(image_url["url"])) is not None: - images.append(b64) - elif isinstance(image_url, str): - if (b64 := _try_extract_base64(image_url)) is not None: - images.append(b64) - # Fallback: treat as base64 string - elif isinstance(image, str): - if (b64 := _try_extract_base64(image)) is not None: - images.append(b64) - - # Fallback: check content field if it looks like image data - if not images: - content = message.content or "" - if content and (content.startswith("data:image/") or 
is_base64_image(content)): - if (b64 := _try_extract_base64(content)) is not None: - images.append(b64) - - if not images: - raise ImageGenerationError("No image data found in image generation response") - - return images - - except Exception: - raise - - def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: - """Generate image(s) using diffusion model via image_generation API. - - Always returns base64. If the API returns URLs instead of inline base64, - the images are downloaded and converted automatically. + for key, value in kwargs.items(): + if key in _COMPLETION_REQUEST_FIELDS: + request_fields[key] = value + else: + metadata[key] = value - Returns: - List of base64-encoded image strings - """ - kwargs = self.consolidate_kwargs(**kwargs) + if metadata: + request_fields["metadata"] = metadata - response = None + return ChatCompletionRequest(**request_fields) - try: - response = self._router.image_generation(prompt=prompt, model=self.model_name, **kwargs) + def _build_embedding_request(self, input_texts: list[str], kwargs: dict[str, Any]) -> EmbeddingRequest: + """Build an EmbeddingRequest from input texts and consolidated kwargs.""" + return EmbeddingRequest( + model=self.model_name, + inputs=input_texts, + timeout=kwargs.get("timeout"), + extra_body=kwargs.get("extra_body"), + extra_headers=kwargs.get("extra_headers"), + ) - logger.debug( - f"Received {len(response.data)} image(s) from diffusion model {self.model_name!r}", - extra={"model": self.model_name, "response": response}, + def _build_image_generation_request( + self, + prompt: str, + multi_modal_context: list[dict[str, Any]] | None, + kwargs: dict[str, Any], + ) -> ImageGenerationRequest: + """Build an ImageGenerationRequest, choosing chat-completion vs diffusion path.""" + is_diffusion = is_image_diffusion_model(self.model_name) + + if is_diffusion: + return ImageGenerationRequest( + model=self.model_name, + prompt=prompt, + n=kwargs.get("n"), + timeout=kwargs.get("timeout"), + extra_body=kwargs.get("extra_body"), + extra_headers=kwargs.get("extra_headers"), ) - # Validate response - if not response.data or len(response.data) == 0: - raise ImageGenerationError("Image generation returned no data") - - images = [b64 for img in response.data if (b64 := _try_extract_base64(img)) is not None] - - if not images: - raise ImageGenerationError("No image data could be extracted from response") - - return images - - except Exception: - raise - finally: - if not skip_usage_tracking and response is not None: - self._track_token_usage_from_image_diffusion(response) - - def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict: - provider = self._model_provider_registry.get_provider(model_config.provider) - api_key = None - if provider.api_key: - api_key = self._secret_resolver.resolve(provider.api_key) - api_key = api_key or "not-used-but-required" - - litellm_params = lazy.litellm.LiteLLM_Params( - model=f"{provider.provider_type}/{model_config.model}", - api_base=provider.endpoint, - api_key=api_key, - max_parallel_requests=model_config.inference_parameters.max_parallel_requests, + chat_messages = [ + m.to_dict() for m in prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) + ] + return ImageGenerationRequest( + model=self.model_name, + prompt=prompt, + messages=chat_messages, + n=kwargs.get("n"), + timeout=kwargs.get("timeout"), + extra_body=kwargs.get("extra_body"), + extra_headers=kwargs.get("extra_headers"), ) - return 
{ - "model_name": model_config.model, - "litellm_params": litellm_params.model_dump(), - } - def _track_token_usage_from_completion(self, response: litellm.types.utils.ModelResponse | None) -> None: - if response is None: + def _track_usage(self, usage: Usage | None, *, is_request_successful: bool) -> None: + """Unified usage tracking from canonical Usage type.""" + if not is_request_successful: self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) return - if ( - response.usage is not None - and response.usage.prompt_tokens is not None - and response.usage.completion_tokens is not None - ): - self._usage_stats.extend( - token_usage=TokenUsageStats( - input_tokens=response.usage.prompt_tokens, - output_tokens=response.usage.completion_tokens, - ), - request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), - ) - def _track_token_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None: - if response is None: - self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) - return - if response.usage is not None and response.usage.prompt_tokens is not None: - self._usage_stats.extend( - token_usage=TokenUsageStats( - input_tokens=response.usage.prompt_tokens, - output_tokens=0, - ), - request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + token_usage = None + if usage is not None and usage.input_tokens is not None: + token_usage = TokenUsageStats( + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens or 0, ) - def _track_token_usage_from_image_diffusion(self, response: litellm.types.utils.ImageResponse | None) -> None: - """Track token usage from image_generation API response.""" - if response is None: - self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) - return - - if response.usage is not None and isinstance(response.usage, lazy.litellm.types.utils.ImageUsage): - self._usage_stats.extend( - token_usage=TokenUsageStats( - input_tokens=response.usage.input_tokens, - output_tokens=response.usage.output_tokens, - ), - request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), - ) - else: - # Successful response but no token usage data (some providers don't report it) - self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=1, failed_requests=0)) + self._usage_stats.extend( + token_usage=token_usage, + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) async def acompletion( self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any - ) -> litellm.ModelResponse: + ) -> ChatCompletionResponse: message_payloads = [message.to_dict() for message in messages] logger.debug( f"Prompting model {self.model_name!r}...", @@ -583,22 +462,21 @@ async def acompletion( response = None kwargs = self.consolidate_kwargs(**kwargs) try: - response = await self._router.acompletion(model=self.model_name, messages=message_payloads, **kwargs) + request = self._build_chat_completion_request(message_payloads, kwargs) + response = await self._client.acompletion(request) logger.debug( f"Received completion from model {self.model_name!r}", extra={ "model": self.model_name, "response": response, - "text": response.choices[0].message.content, + "text": response.message.content, "usage": self._usage_stats.model_dump(), }, ) return response - except Exception as e: - raise e finally: if not skip_usage_tracking and 
response is not None: - self._track_token_usage_from_completion(response) + self._track_usage(response.usage, is_request_successful=True) @acatch_llm_exceptions async def agenerate_text_embeddings( @@ -612,26 +490,24 @@ async def agenerate_text_embeddings( }, ) kwargs = self.consolidate_kwargs(**kwargs) - response = None + response: EmbeddingResponse | None = None try: - response = await self._router.aembedding(model=self.model_name, input=input_texts, **kwargs) + request = self._build_embedding_request(input_texts, kwargs) + response = await self._client.aembeddings(request) logger.debug( f"Received embeddings from model {self.model_name!r}", extra={ "model": self.model_name, - "embedding_count": len(response.data) if response.data else 0, + "embedding_count": len(response.vectors), "usage": self._usage_stats.model_dump(), }, ) - if response.data and len(response.data) == len(input_texts): - return [data["embedding"] for data in response.data] - else: - raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") - except Exception as e: - raise e + if len(response.vectors) == len(input_texts): + return response.vectors + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.vectors)}") finally: if not skip_usage_tracking and response is not None: - self._track_token_usage_from_embedding(response) + self._track_usage(response.usage, is_request_successful=True) @acatch_llm_exceptions async def agenerate( @@ -693,8 +569,8 @@ async def agenerate( continue - response = (completion_response.choices[0].message.content or "").strip() - reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None) + response = (completion_response.message.content or "").strip() + reasoning_trace = completion_response.message.reasoning_content messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) curr_num_correction_steps += 1 @@ -741,8 +617,8 @@ async def agenerate_image( """Async version of generate_image. Generate image(s) and return base64-encoded data. Automatically detects the appropriate API based on model name: - - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) → image_generation API - - All other models → chat/completions API (default) + - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) -> image_generation API + - All other models -> chat/completions API (default) Both paths return base64-encoded image data. If the API returns multiple images, all are returned in the list. @@ -765,133 +641,26 @@ async def agenerate_image( extra={"model": self.model_name, "prompt": prompt}, ) - # Auto-detect API type based on model name - if is_image_diffusion_model(self.model_name): - images = await self._agenerate_image_diffusion(prompt, skip_usage_tracking, **kwargs) - else: - images = await self._agenerate_image_chat_completion( - prompt, multi_modal_context, skip_usage_tracking, **kwargs - ) - - # Track image usage - if not skip_usage_tracking and len(images) > 0: - self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) - - return images - - async def _agenerate_image_chat_completion( - self, - prompt: str, - multi_modal_context: list[dict[str, Any]] | None = None, - skip_usage_tracking: bool = False, - **kwargs: Any, - ) -> list[str]: - """Async version of _generate_image_chat_completion. - - Generate image(s) using autoregressive model via chat completions API. 
- - Args: - prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation - skip_usage_tracking: Whether to skip usage tracking - **kwargs: Additional arguments to pass to the model - - Returns: - List of base64-encoded image strings - """ - messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) - - response = None - try: - response = await self.acompletion( - messages=messages, - skip_usage_tracking=skip_usage_tracking, - **kwargs, - ) - - logger.debug( - f"Received image(s) from autoregressive model {self.model_name!r}", - extra={"model": self.model_name, "response": response}, - ) - - # Validate response structure - if not response.choices or len(response.choices) == 0: - raise ImageGenerationError("Image generation response missing choices") - - message = response.choices[0].message - images = [] - - # Extract base64 from images attribute (primary path) - if hasattr(message, "images") and message.images: - for image in message.images: - # Handle different response formats - if isinstance(image, dict) and "image_url" in image: - image_url = image["image_url"] - - if isinstance(image_url, dict) and "url" in image_url: - if (b64 := _try_extract_base64(image_url["url"])) is not None: - images.append(b64) - elif isinstance(image_url, str): - if (b64 := _try_extract_base64(image_url)) is not None: - images.append(b64) - # Fallback: treat as base64 string - elif isinstance(image, str): - if (b64 := _try_extract_base64(image)) is not None: - images.append(b64) - - # Fallback: check content field if it looks like image data - if not images: - content = message.content or "" - if content and (content.startswith("data:image/") or is_base64_image(content)): - if (b64 := _try_extract_base64(content)) is not None: - images.append(b64) - - if not images: - raise ImageGenerationError("No image data found in image generation response") - - return images - - except Exception: - raise - - async def _agenerate_image_diffusion( - self, prompt: str, skip_usage_tracking: bool = False, **kwargs: Any - ) -> list[str]: - """Async version of _generate_image_diffusion. - - Generate image(s) using diffusion model via image_generation API. - - Always returns base64. If the API returns URLs instead of inline base64, - the images are downloaded and converted automatically. 
- - Returns: - List of base64-encoded image strings - """ kwargs = self.consolidate_kwargs(**kwargs) + request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) + response = await self._client.agenerate_image(request) - response = None + images = [img.b64_data for img in response.images] - try: - response = await self._router.aimage_generation(prompt=prompt, model=self.model_name, **kwargs) + if not images: + raise ImageGenerationError("No image data found in image generation response") - logger.debug( - f"Received {len(response.data)} image(s) from diffusion model {self.model_name!r}", - extra={"model": self.model_name, "response": response}, - ) - - # Validate response - if not response.data or len(response.data) == 0: - raise ImageGenerationError("Image generation returned no data") - - images = [b64 for img in response.data if (b64 := _try_extract_base64(img)) is not None] + # Track image usage + if not skip_usage_tracking and images: + self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) + self._track_usage(response.usage, is_request_successful=True) - if not images: - raise ImageGenerationError("No image data could be extracted from response") + return images - return images + def close(self) -> None: + """Release resources held by the underlying client.""" + self._client.close() - except Exception: - raise - finally: - if not skip_usage_tracking and response is not None: - self._track_token_usage_from_image_diffusion(response) + async def aclose(self) -> None: + """Async release resources held by the underlying client.""" + await self._client.aclose() diff --git a/packages/data-designer-engine/src/data_designer/engine/models/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/factory.py index fb3b2e1d4..a23c0dbcc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/factory.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/factory.py @@ -37,17 +37,23 @@ def create_model_registry( Returns: A configured ModelRegistry instance. 
""" + from data_designer.engine.models.clients.factory import create_model_client from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.litellm_overrides import apply_litellm_patches from data_designer.engine.models.registry import ModelRegistry apply_litellm_patches() - def model_facade_factory(model_config, secret_resolver, model_provider_registry): + def model_facade_factory( + model_config: ModelConfig, + secret_resolver: SecretResolver, + model_provider_registry: ModelProviderRegistry, + ) -> ModelFacade: + client = create_model_client(model_config, secret_resolver, model_provider_registry) return ModelFacade( model_config, - secret_resolver, model_provider_registry, + client=client, mcp_registry=mcp_registry, ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py index 0b103e76b..b4dff0301 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py @@ -5,7 +5,7 @@ import logging from collections.abc import Callable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from data_designer.config.models import GenerationType, ModelConfig from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry @@ -27,7 +27,7 @@ def __init__( model_provider_registry: ModelProviderRegistry, model_configs: list[ModelConfig] | None = None, model_facade_factory: Callable[[ModelConfig, SecretResolver, ModelProviderRegistry], ModelFacade] | None = None, - ): + ) -> None: self._secret_resolver = secret_resolver self._model_provider_registry = model_provider_registry self._model_facade_factory = model_facade_factory @@ -69,7 +69,7 @@ def get_model_config(self, *, model_alias: str) -> ModelConfig: raise ValueError(f"No model config with alias {model_alias!r} found!") return self._model_configs[model_alias] - def get_model_usage_stats(self, total_time_elapsed: float) -> dict[str, dict]: + def get_model_usage_stats(self, total_time_elapsed: float) -> dict[str, dict[str, Any]]: return { model.model_name: model.usage_stats.get_usage_stats(total_time_elapsed=total_time_elapsed) for model in self._models.values() @@ -200,10 +200,18 @@ def run_health_check(self, model_aliases: list[str]) -> None: logger.error(f"{LOG_INDENT}❌ Failed!") raise e - def _set_model_configs(self, model_configs: list[ModelConfig]) -> None: - model_configs = model_configs or [] - self._model_configs = {mc.alias: mc for mc in model_configs} - # Models are now lazily initialized in get_model() when first requested + def _set_model_configs(self, model_configs: list[ModelConfig] | None) -> None: + self._model_configs = {mc.alias: mc for mc in (model_configs or [])} + + def close(self) -> None: + """Release resources held by all model facades.""" + for facade in self._models.values(): + facade.close() + + async def aclose(self) -> None: + """Async release resources held by all model facades.""" + for facade in self._models.values(): + await facade.aclose() def _get_model(self, model_config: ModelConfig) -> ModelFacade: if self._model_facade_factory is None: diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py b/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py index 3d01db6ac..a3380a140 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py +++ 
b/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py @@ -10,6 +10,7 @@ StubMCPRegistry, StubMessage, StubResponse, + make_stub_completion_response, ) from data_designer.engine.testing.utils import assert_valid_plugin @@ -21,4 +22,5 @@ "StubMessage", "StubResponse", assert_valid_plugin.__name__, + make_stub_completion_response.__name__, ] diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py b/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py index af7d3ebfc..e47b2ffbc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py +++ b/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py @@ -16,8 +16,9 @@ from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, ToolConfig from data_designer.engine.mcp.facade import MCPFacade from data_designer.engine.model_provider import MCPProviderRegistry +from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, ToolCall from data_designer.engine.secret_resolver import SecretResolver -from data_designer.engine.testing.stubs import StubHuggingFaceSeedReader, StubMessage, StubResponse +from data_designer.engine.testing.stubs import StubHuggingFaceSeedReader # ============================================================================= # Seed reader fixtures @@ -151,61 +152,66 @@ def factory( # ============================================================================= -# Completion response fixtures +# Completion response fixtures (canonical ChatCompletionResponse) # ============================================================================= @pytest.fixture -def mock_completion_response_no_tools() -> StubResponse: +def mock_completion_response_no_tools() -> ChatCompletionResponse: """Mock LLM response with no tool calls.""" - return StubResponse(StubMessage(content="Hello, I can help with that.")) + return ChatCompletionResponse( + message=AssistantMessage(content="Hello, I can help with that."), + ) @pytest.fixture -def mock_completion_response_single_tool() -> StubResponse: +def mock_completion_response_single_tool() -> ChatCompletionResponse: """Mock LLM response with single tool call.""" - tool_call = { - "id": "call-1", - "type": "function", - "function": {"name": "lookup", "arguments": '{"query": "test"}'}, - } - return StubResponse(StubMessage(content="Let me look that up.", tool_calls=[tool_call])) + return ChatCompletionResponse( + message=AssistantMessage( + content="Let me look that up.", + tool_calls=[ + ToolCall(id="call-1", name="lookup", arguments_json='{"query": "test"}'), + ], + ), + ) @pytest.fixture -def mock_completion_response_parallel_tools() -> StubResponse: +def mock_completion_response_parallel_tools() -> ChatCompletionResponse: """Mock LLM response with multiple parallel tool calls.""" - tool_calls = [ - {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "first"}'}}, - {"id": "call-2", "type": "function", "function": {"name": "search", "arguments": '{"term": "second"}'}}, - {"id": "call-3", "type": "function", "function": {"name": "fetch", "arguments": '{"url": "example.com"}'}}, - ] - return StubResponse(StubMessage(content="Executing multiple tools.", tool_calls=tool_calls)) + return ChatCompletionResponse( + message=AssistantMessage( + content="Executing multiple tools.", + tool_calls=[ + ToolCall(id="call-1", name="lookup", arguments_json='{"query": "first"}'), + ToolCall(id="call-2", name="search", 
arguments_json='{"term": "second"}'), + ToolCall(id="call-3", name="fetch", arguments_json='{"url": "example.com"}'), + ], + ), + ) @pytest.fixture -def mock_completion_response_with_reasoning() -> StubResponse: +def mock_completion_response_with_reasoning() -> ChatCompletionResponse: """Mock LLM response with reasoning_content.""" - return StubResponse( - StubMessage( + return ChatCompletionResponse( + message=AssistantMessage( content=" Final answer with extra spaces. ", reasoning_content=" Thinking about the problem... ", - ) + ), ) @pytest.fixture -def mock_completion_response_tool_with_reasoning() -> StubResponse: +def mock_completion_response_tool_with_reasoning() -> ChatCompletionResponse: """Mock LLM response with tool calls and reasoning_content.""" - tool_call = { - "id": "call-1", - "type": "function", - "function": {"name": "lookup", "arguments": '{"query": "test"}'}, - } - return StubResponse( - StubMessage( + return ChatCompletionResponse( + message=AssistantMessage( content=" Looking it up... ", - tool_calls=[tool_call], + tool_calls=[ + ToolCall(id="call-1", name="lookup", arguments_json='{"query": "test"}'), + ], reasoning_content=" I should use the lookup tool. ", - ) + ), ) diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py b/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py index 60567b51e..4807b64c8 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py +++ b/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py @@ -8,6 +8,7 @@ from data_designer.config.base import ConfigBase, SingleColumnConfig from data_designer.engine.column_generators.generators.base import ColumnGeneratorCellByCell +from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, ToolCall from data_designer.engine.models.utils import ChatMessage from data_designer.engine.resources.seed_reader import SeedReader from data_designer.plugins.plugin import Plugin, PluginType @@ -24,7 +25,7 @@ def get_column_names(self) -> list[str]: def get_dataset_uri(self) -> str: return "unused in these tests" - def create_duckdb_connection(self): + def create_duckdb_connection(self) -> None: pass def get_seed_type(self) -> str: @@ -41,7 +42,7 @@ class ValidTestConfig(SingleColumnConfig): class ValidTestTask(ColumnGeneratorCellByCell[ValidTestConfig]): """Valid task for testing plugin creation.""" - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data @@ -70,12 +71,12 @@ class StubPluginConfigB(SingleColumnConfig): class StubPluginTaskA(ColumnGeneratorCellByCell[StubPluginConfigA]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data class StubPluginTaskB(ColumnGeneratorCellByCell[StubPluginConfigB]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data @@ -95,17 +96,17 @@ class StubPluginConfigBlobsAndSeeds(SingleColumnConfig): class StubPluginTaskModels(ColumnGeneratorCellByCell[StubPluginConfigModels]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data class StubPluginTaskModelsAndBlobs(ColumnGeneratorCellByCell[StubPluginConfigModelsAndBlobs]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data class 
StubPluginTaskBlobsAndSeeds(ColumnGeneratorCellByCell[StubPluginConfigBlobsAndSeeds]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data @@ -135,7 +136,7 @@ def generate(self, data: dict) -> dict: # ============================================================================= -# Stub LLM response classes for testing +# Stub LLM response classes for testing (legacy, kept for backward compat) # ============================================================================= @@ -173,6 +174,26 @@ def __init__(self, message: StubMessage) -> None: self.choices = [StubChoice(message)] +# ============================================================================= +# Canonical stub helpers +# ============================================================================= + + +def make_stub_completion_response( + content: str | None = None, + reasoning_content: str | None = None, + tool_calls: list[ToolCall] | None = None, +) -> ChatCompletionResponse: + """Factory helper for creating canonical ChatCompletionResponse test objects.""" + return ChatCompletionResponse( + message=AssistantMessage( + content=content, + reasoning_content=reasoning_content, + tool_calls=tool_calls or [], + ), + ) + + # ============================================================================= # Stub MCP classes for testing tool calling # ============================================================================= @@ -195,8 +216,8 @@ def __init__( self, max_tool_call_turns: int = 3, tool_schemas: list[dict[str, Any]] | None = None, - process_fn: Callable[[Any], list[ChatMessage]] | None = None, - refuse_fn: Callable[[Any], list[ChatMessage]] | None = None, + process_fn: Callable[[ChatCompletionResponse], list[ChatMessage]] | None = None, + refuse_fn: Callable[[ChatCompletionResponse], list[ChatMessage]] | None = None, ) -> None: self.tool_alias = "tools" self.providers = ["tools"] @@ -208,34 +229,49 @@ def __init__( def get_tool_schemas(self) -> list[dict[str, Any]]: return self._tool_schemas - def tool_call_count(self, completion_response: Any) -> int: - tool_calls = getattr(completion_response.choices[0].message, "tool_calls", None) - return len(tool_calls) if tool_calls else 0 + def tool_call_count(self, completion_response: ChatCompletionResponse) -> int: + return len(completion_response.message.tool_calls) - def has_tool_calls(self, completion_response: Any) -> bool: - return completion_response.choices[0].message.tool_calls is not None + def has_tool_calls(self, completion_response: ChatCompletionResponse) -> bool: + return len(completion_response.message.tool_calls) > 0 - def process_completion_response(self, completion_response: Any) -> list[ChatMessage]: + def process_completion_response(self, completion_response: ChatCompletionResponse) -> list[ChatMessage]: if self._process_fn: return self._process_fn(completion_response) - message = completion_response.choices[0].message - tool_calls = message.tool_calls or [] + message = completion_response.message + tool_calls = message.tool_calls + tool_call_dicts = [ + { + "id": tc.id, + "type": "function", + "function": {"name": tc.name, "arguments": tc.arguments_json}, + } + for tc in tool_calls + ] return [ - ChatMessage.as_assistant(content=message.content or "", tool_calls=tool_calls), - *[ChatMessage.as_tool(content="tool-result", tool_call_id=tc["id"]) for tc in tool_calls], + ChatMessage.as_assistant(content=message.content or "", tool_calls=tool_call_dicts), + 
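+            # One assistant message carrying the tool calls, then one stub tool-result per call.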
*[ChatMessage.as_tool(content="tool-result", tool_call_id=tc.id) for tc in tool_calls], ] - def refuse_completion_response(self, completion_response: Any) -> list[ChatMessage]: + def refuse_completion_response(self, completion_response: ChatCompletionResponse) -> list[ChatMessage]: if self._refuse_fn: return self._refuse_fn(completion_response) - message = completion_response.choices[0].message - tool_calls = message.tool_calls or [] + message = completion_response.message + tool_calls = message.tool_calls + tool_call_dicts = [ + { + "id": tc.id, + "type": "function", + "function": {"name": tc.name, "arguments": tc.arguments_json}, + } + for tc in tool_calls + ] return [ - ChatMessage.as_assistant(content="", tool_calls=tool_calls), + ChatMessage.as_assistant(content="", tool_calls=tool_call_dicts), *[ ChatMessage.as_tool( content="Tool call refused: maximum tool-calling turns reached.", - tool_call_id=tc["id"], + tool_call_id=tc.id, ) for tc in tool_calls ], @@ -243,14 +279,7 @@ def refuse_completion_response(self, completion_response: Any) -> list[ChatMessa class StubMCPRegistry: - """Stub MCP registry that returns a configurable StubMCPFacade. - - This stub provides a simple registry implementation for testing that - returns the configured StubMCPFacade instance. - - Args: - facade: The StubMCPFacade instance to return. If None, creates a default one. - """ + """Stub MCP registry that returns a configurable StubMCPFacade.""" def __init__(self, facade: StubMCPFacade | None = None) -> None: self._facade = facade or StubMCPFacade() diff --git a/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py b/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py index 610fa53f9..5dae4fdbc 100644 --- a/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py +++ b/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py @@ -10,39 +10,26 @@ from data_designer.config.mcp import LocalStdioMCPProvider, ToolConfig from data_designer.engine.mcp import io as mcp_io -from data_designer.engine.mcp.errors import DuplicateToolNameError, MCPToolError +from data_designer.engine.mcp.errors import DuplicateToolNameError, MCPConfigurationError, MCPToolError from data_designer.engine.mcp.facade import DEFAULT_TOOL_REFUSAL_MESSAGE, MCPFacade from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult from data_designer.engine.model_provider import MCPProviderRegistry - - -# Fake classes are used directly in tests to create custom responses -class FakeMessage: - """Fake message class for mocking LLM completion responses.""" - - def __init__( - self, - content: str | None, - tool_calls: list[dict] | None = None, - reasoning_content: str | None = None, - ) -> None: - self.content = content - self.tool_calls = tool_calls - self.reasoning_content = reasoning_content - - -class FakeChoice: - """Fake choice class for mocking LLM completion responses.""" - - def __init__(self, message: FakeMessage) -> None: - self.message = message - - -class FakeResponse: - """Fake response class for mocking LLM completion responses.""" - - def __init__(self, message: FakeMessage) -> None: - self.choices = [FakeChoice(message)] +from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, ToolCall + + +def _make_response( + content: str | None = None, + tool_calls: list[ToolCall] | None = None, + reasoning_content: str | None = None, +) -> ChatCompletionResponse: + """Shorthand for creating canonical test responses.""" + return 
ChatCompletionResponse( + message=AssistantMessage( + content=content, + reasoning_content=reasoning_content, + tool_calls=tool_calls or [], + ), + ) # ============================================================================= @@ -50,24 +37,24 @@ def __init__(self, message: FakeMessage) -> None: # ============================================================================= -def test_tool_call_count_no_tools(mock_completion_response_no_tools: FakeResponse) -> None: +def test_tool_call_count_no_tools(mock_completion_response_no_tools: ChatCompletionResponse) -> None: """Returns 0 when response has no tool calls.""" assert MCPFacade.tool_call_count(mock_completion_response_no_tools) == 0 -def test_tool_call_count_single_tool(mock_completion_response_single_tool: FakeResponse) -> None: +def test_tool_call_count_single_tool(mock_completion_response_single_tool: ChatCompletionResponse) -> None: """Returns 1 for single tool call.""" assert MCPFacade.tool_call_count(mock_completion_response_single_tool) == 1 -def test_tool_call_count_parallel_tools(mock_completion_response_parallel_tools: FakeResponse) -> None: +def test_tool_call_count_parallel_tools(mock_completion_response_parallel_tools: ChatCompletionResponse) -> None: """Returns correct count for parallel tool calls (e.g., 3).""" assert MCPFacade.tool_call_count(mock_completion_response_parallel_tools) == 3 def test_tool_call_count_none_tool_calls_attribute() -> None: - """Returns 0 when tool_calls attribute is None.""" - response = FakeResponse(FakeMessage(content="Hello", tool_calls=None)) + """Returns 0 when tool_calls is empty.""" + response = _make_response(content="Hello") assert MCPFacade.tool_call_count(response) == 0 @@ -76,12 +63,12 @@ def test_tool_call_count_none_tool_calls_attribute() -> None: # ============================================================================= -def test_has_tool_calls_true(mock_completion_response_single_tool: FakeResponse) -> None: +def test_has_tool_calls_true(mock_completion_response_single_tool: ChatCompletionResponse) -> None: """Returns True when tool calls are present.""" assert MCPFacade.has_tool_calls(mock_completion_response_single_tool) is True -def test_has_tool_calls_false(mock_completion_response_no_tools: FakeResponse) -> None: +def test_has_tool_calls_false(mock_completion_response_no_tools: ChatCompletionResponse) -> None: """Returns False when no tool calls are present.""" assert MCPFacade.has_tool_calls(mock_completion_response_no_tools) is False @@ -93,7 +80,7 @@ def test_has_tool_calls_false(mock_completion_response_no_tools: FakeResponse) - def test_process_completion_no_tool_calls( stub_mcp_facade: MCPFacade, - mock_completion_response_no_tools: FakeResponse, + mock_completion_response_no_tools: ChatCompletionResponse, ) -> None: """Returns [assistant_message] when no tool calls present.""" messages = stub_mcp_facade.process_completion_response(mock_completion_response_no_tools) @@ -107,7 +94,7 @@ def test_process_completion_no_tool_calls( def test_process_completion_with_tool_calls( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: """Returns [assistant_msg, tool_msg] for tool calls.""" @@ -137,7 +124,7 @@ def mock_call_tools( def test_process_completion_preserves_content( stub_mcp_facade: MCPFacade, - mock_completion_response_no_tools: FakeResponse, + mock_completion_response_no_tools: ChatCompletionResponse, ) -> None: """Assistant 
content is preserved in returned message.""" messages = stub_mcp_facade.process_completion_response(mock_completion_response_no_tools) @@ -147,7 +134,7 @@ def test_process_completion_preserves_content( def test_process_completion_preserves_reasoning_content( stub_mcp_facade: MCPFacade, - mock_completion_response_with_reasoning: FakeResponse, + mock_completion_response_with_reasoning: ChatCompletionResponse, ) -> None: """Reasoning content is preserved when present.""" messages = stub_mcp_facade.process_completion_response(mock_completion_response_with_reasoning) @@ -158,7 +145,7 @@ def test_process_completion_preserves_reasoning_content( def test_process_completion_strips_whitespace_with_reasoning( stub_mcp_facade: MCPFacade, - mock_completion_response_with_reasoning: FakeResponse, + mock_completion_response_with_reasoning: ChatCompletionResponse, ) -> None: """Content and reasoning are stripped when reasoning is present.""" messages = stub_mcp_facade.process_completion_response(mock_completion_response_with_reasoning) @@ -170,7 +157,7 @@ def test_process_completion_strips_whitespace_with_reasoning( def test_process_completion_parallel_tool_calls( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, - mock_completion_response_parallel_tools: FakeResponse, + mock_completion_response_parallel_tools: ChatCompletionResponse, ) -> None: """All parallel tool calls are executed and messages returned.""" @@ -222,13 +209,10 @@ def test_process_completion_tool_not_in_allow_list( mcp_provider_registry=stub_mcp_provider_registry, ) - # Tool "forbidden" is not in allow_tools ["lookup", "search"] - tool_call = { - "id": "call-1", - "type": "function", - "function": {"name": "forbidden", "arguments": "{}"}, - } - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) + response = _make_response( + content="", + tool_calls=[ToolCall(id="call-1", name="forbidden", arguments_json="{}")], + ) with pytest.raises(MCPToolError, match="not permitted"): facade.process_completion_response(response) @@ -253,8 +237,10 @@ def mock_call_tools( monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} - response = FakeResponse(FakeMessage(content=None, tool_calls=[tool_call])) + response = _make_response( + content=None, + tool_calls=[ToolCall(id="call-1", name="lookup", arguments_json="{}")], + ) messages = stub_mcp_facade.process_completion_response(response) @@ -270,7 +256,7 @@ def mock_call_tools( def test_refuse_completion_no_tool_calls( stub_mcp_facade: MCPFacade, - mock_completion_response_no_tools: FakeResponse, + mock_completion_response_no_tools: ChatCompletionResponse, ) -> None: """Returns [assistant_message] when no tool calls to refuse.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_no_tools) @@ -282,7 +268,7 @@ def test_refuse_completion_no_tool_calls( def test_refuse_completion_single_tool( stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: """Returns assistant + refusal message for single tool call.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_single_tool) @@ -297,7 +283,7 @@ def test_refuse_completion_single_tool( def test_refuse_completion_parallel_tools( stub_mcp_facade: MCPFacade, - mock_completion_response_parallel_tools: FakeResponse, + 
mock_completion_response_parallel_tools: ChatCompletionResponse, ) -> None: """Returns assistant + refusal for each parallel tool call.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_parallel_tools) @@ -313,7 +299,7 @@ def test_refuse_completion_parallel_tools( def test_refuse_completion_default_message( stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: """Uses default refusal message when none provided.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_single_tool) @@ -323,7 +309,7 @@ def test_refuse_completion_default_message( def test_refuse_completion_custom_message( stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: """Uses custom refusal message when provided.""" custom_message = "Custom refusal: Budget exceeded." @@ -337,12 +323,11 @@ def test_refuse_completion_custom_message( def test_refuse_completion_preserves_tool_call_ids( stub_mcp_facade: MCPFacade, - mock_completion_response_parallel_tools: FakeResponse, + mock_completion_response_parallel_tools: ChatCompletionResponse, ) -> None: """Refusal messages have correct tool_call_id linkage.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_parallel_tools) - # Verify each refusal message has the correct tool_call_id assert messages[1].tool_call_id == "call-1" assert messages[2].tool_call_id == "call-2" assert messages[3].tool_call_id == "call-3" @@ -350,7 +335,7 @@ def test_refuse_completion_preserves_tool_call_ids( def test_refuse_completion_preserves_reasoning( stub_mcp_facade: MCPFacade, - mock_completion_response_tool_with_reasoning: FakeResponse, + mock_completion_response_tool_with_reasoning: ChatCompletionResponse, ) -> None: """Reasoning content preserved in refusal scenario.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_tool_with_reasoning) @@ -363,7 +348,7 @@ def test_refuse_completion_preserves_reasoning( def test_refuse_does_not_call_mcp_server( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: """Verify MCP server is NOT called during refusal.""" call_tools_called = False @@ -484,40 +469,20 @@ def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MC monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - from data_designer.engine.mcp.errors import MCPConfigurationError - with pytest.raises(MCPConfigurationError, match="not found"): facade.get_tool_schemas() # ============================================================================= -# Tool call normalization via public API (process_completion_response) +# Tool call handling via public API (process_completion_response) # ============================================================================= -def test_process_completion_missing_tool_name(stub_mcp_facade: MCPFacade) -> None: - """process_completion_response raises MCPToolError when tool call has no name.""" - tool_call = {"id": "call-1", "function": {"arguments": "{}"}} # Missing name - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) - - with pytest.raises(MCPToolError, match="missing a tool name"): - stub_mcp_facade.process_completion_response(response) - - -def 
test_process_completion_invalid_json_arguments(stub_mcp_facade: MCPFacade) -> None: - """process_completion_response raises MCPToolError when arguments are invalid JSON.""" - tool_call = {"id": "call-1", "function": {"name": "lookup", "arguments": "not valid json"}} - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) - - with pytest.raises(MCPToolError, match="Invalid tool arguments"): - stub_mcp_facade.process_completion_response(response) - - -def test_process_completion_dict_arguments( +def test_process_completion_with_empty_arguments( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, ) -> None: - """process_completion_response handles dict arguments correctly.""" + """process_completion_response handles empty arguments gracefully.""" def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) @@ -536,21 +501,22 @@ def mock_call_tools( monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - # Pass dict arguments (not JSON string) - tool_call = {"id": "call-1", "function": {"name": "lookup", "arguments": {"query": "test"}}} - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) + response = _make_response( + content="", + tool_calls=[ToolCall(id="call-1", name="lookup", arguments_json="{}")], + ) messages = stub_mcp_facade.process_completion_response(response) assert len(messages) == 2 - assert captured_args[0] == {"query": "test"} + assert captured_args[0] == {} -def test_process_completion_empty_arguments( +def test_process_completion_with_dict_arguments( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, ) -> None: - """process_completion_response handles None/empty arguments gracefully.""" + """process_completion_response handles arguments via canonical ToolCall correctly.""" def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) @@ -569,85 +535,15 @@ def mock_call_tools( monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - tool_call = {"id": "call-1", "function": {"name": "lookup", "arguments": None}} - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) - - messages = stub_mcp_facade.process_completion_response(response) - - assert len(messages) == 2 - assert captured_args[0] == {} # Empty dict for None arguments - - -def test_process_completion_generates_tool_call_id( - monkeypatch: pytest.MonkeyPatch, - stub_mcp_facade: MCPFacade, -) -> None: - """process_completion_response generates UUID for tool calls without ID.""" - - def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) - - def mock_call_tools( - calls: list[tuple[Any, str, dict[str, Any]]], - *, - timeout_sec: float | None = None, - ) -> list[MCPToolResult]: - return [MCPToolResult(content="result") for _ in calls] - - monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - - # Tool call without id - tool_call = {"function": {"name": "lookup", "arguments": "{}"}} - response = FakeResponse(FakeMessage(content="", 
tool_calls=[tool_call])) - - messages = stub_mcp_facade.process_completion_response(response) - - # Should have generated an ID - assert len(messages) == 2 - assert messages[1].tool_call_id is not None - assert len(messages[1].tool_call_id) == 32 # UUID hex format - - -def test_process_completion_object_format_tool_calls( - monkeypatch: pytest.MonkeyPatch, - stub_mcp_facade: MCPFacade, -) -> None: - """process_completion_response handles object format tool calls.""" - - def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) - - captured_calls: list[tuple[str, dict[str, Any]]] = [] - - def mock_call_tools( - calls: list[tuple[Any, str, dict[str, Any]]], - *, - timeout_sec: float | None = None, - ) -> list[MCPToolResult]: - for _, tool_name, args in calls: - captured_calls.append((tool_name, args)) - return [MCPToolResult(content="result") for _ in calls] - - monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - - # Create object format tool call (simulating what some LLM libraries return) - class FakeFunction: - name = "lookup" - arguments = '{"query": "test"}' - - class FakeToolCall: - id = "call-obj-1" - function = FakeFunction() - - response = FakeResponse(FakeMessage(content="", tool_calls=[FakeToolCall()])) + response = _make_response( + content="", + tool_calls=[ToolCall(id="call-1", name="lookup", arguments_json='{"query": "test"}')], + ) messages = stub_mcp_facade.process_completion_response(response) assert len(messages) == 2 - assert captured_calls[0] == ("lookup", {"query": "test"}) - assert messages[1].tool_call_id == "call-obj-1" + assert captured_args[0] == {"query": "test"} # ============================================================================= @@ -771,7 +667,6 @@ def test_get_tool_schemas_duplicate_tool_names_raises_error( ) def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - # Both providers have a tool named "lookup" if provider.name == "tools": return ( MCPToolDefinition(name="lookup", description="Lookup from tools", input_schema={"type": "object"}), @@ -804,7 +699,6 @@ def test_get_tool_schemas_duplicate_tool_names_reports_all_duplicates( ) def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - # Both providers have "lookup" and "search" as duplicates if provider.name == "tools": return ( MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}), @@ -820,7 +714,6 @@ def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MC with pytest.raises(DuplicateToolNameError) as exc_info: facade.get_tool_schemas() - # Both duplicates should be reported assert "lookup" in str(exc_info.value) assert "search" in str(exc_info.value) @@ -841,14 +734,12 @@ def test_get_tool_schemas_no_duplicates_passes( ) def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - # Each provider has unique tool names if provider.name == "tools": return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) return (MCPToolDefinition(name="fetch", description="Fetch", input_schema={"type": "object"}),) monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - # Should not raise schemas = facade.get_tool_schemas() assert len(schemas) == 2 @@ -867,6 
+758,5 @@ def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MC monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - # Should not raise schemas = stub_mcp_facade.get_tool_schemas() assert len(schemas) == 2 diff --git a/packages/data-designer-engine/tests/engine/models/conftest.py b/packages/data-designer-engine/tests/engine/models/conftest.py index 601dbb4ad..965d815e4 100644 --- a/packages/data-designer-engine/tests/engine/models/conftest.py +++ b/packages/data-designer-engine/tests/engine/models/conftest.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path +from unittest.mock import MagicMock import pytest @@ -11,6 +12,7 @@ ModelConfig, ) from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry +from data_designer.engine.models.clients.base import ModelClient from data_designer.engine.models.factory import create_model_registry from data_designer.engine.models.registry import ModelRegistry from data_designer.engine.secret_resolver import SecretsFileResolver @@ -68,7 +70,11 @@ def stub_model_configs() -> list[ModelConfig]: @pytest.fixture -def stub_model_registry(stub_model_configs, stub_secrets_resolver, stub_model_provider_registry) -> ModelRegistry: +def stub_model_registry( + stub_model_configs: list[ModelConfig], + stub_secrets_resolver: SecretsFileResolver, + stub_model_provider_registry: ModelProviderRegistry, +) -> ModelRegistry: return create_model_registry( model_configs=stub_model_configs, secret_resolver=stub_secrets_resolver, @@ -76,6 +82,12 @@ def stub_model_registry(stub_model_configs, stub_secrets_resolver, stub_model_pr ) +@pytest.fixture +def stub_model_client() -> MagicMock: + """Mock ModelClient for testing ModelFacade without a real LiteLLM router.""" + return MagicMock(spec=ModelClient) + + @pytest.fixture def stub_mcp_facade_for_model() -> StubMCPFacade: """Default stub MCP facade with max_tool_call_turns=3.""" diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 662cbc762..ae92682a0 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -3,33 +3,41 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest -import data_designer.lazy_heavy_imports as lazy from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError +from data_designer.engine.models.clients.types import ( + ChatCompletionResponse, + EmbeddingResponse, + ImageGenerationResponse, + ImagePayload, + ToolCall, +) from data_designer.engine.models.errors import ImageGenerationError, ModelGenerationValidationFailureError -from data_designer.engine.models.facade import CustomRouter, ModelFacade +from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.parsers.errors import ParserException from data_designer.engine.models.utils import ChatMessage -from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, StubMessage, StubResponse - -if TYPE_CHECKING: - from litellm.types.utils import EmbeddingResponse, ModelResponse +from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, make_stub_completion_response -def mock_oai_response_object(response_text: str) -> StubResponse: - return StubResponse(StubMessage(content=response_text)) +def 
_make_response(content: str | None = None, **kwargs: Any) -> ChatCompletionResponse: + """Shorthand for creating a ChatCompletionResponse in tests.""" + return make_stub_completion_response(content=content, **kwargs) @pytest.fixture -def stub_model_facade(stub_model_configs, stub_secrets_resolver, stub_model_provider_registry): +def stub_model_facade( + stub_model_configs: list[Any], + stub_model_client: MagicMock, + stub_model_provider_registry: Any, +) -> ModelFacade: return ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, ) @@ -38,18 +46,6 @@ def stub_completion_messages() -> list[ChatMessage]: return [ChatMessage.as_user("test")] -@pytest.fixture -def stub_expected_completion_response(): - return lazy.litellm.types.utils.ModelResponse( - choices=lazy.litellm.types.utils.Choices(message=lazy.litellm.types.utils.Message(content="Test response")) - ) - - -@pytest.fixture -def stub_expected_embedding_response(): - return lazy.litellm.types.utils.EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2) - - @pytest.mark.parametrize( "max_correction_steps,max_conversation_restarts,total_calls", [ @@ -69,7 +65,7 @@ def test_generate( max_conversation_restarts: int, total_calls: int, ) -> None: - bad_response = mock_oai_response_object("bad response") + bad_response = _make_response("bad response") mock_completion.side_effect = lambda *args, **kwargs: bad_response def _failing_parser(response: str) -> str: @@ -110,14 +106,11 @@ def test_generate_with_system_prompt( system_prompt: str, expected_messages: list[ChatMessage], ) -> None: - # Capture messages at call time since they get mutated after the call captured_messages = [] - def capture_and_return(*args: Any, **kwargs: Any) -> ModelResponse: - captured_messages.append(list(args[1])) # Copy the messages list - return lazy.litellm.types.utils.ModelResponse( - choices=lazy.litellm.types.utils.Choices(message=lazy.litellm.types.utils.Message(content="Hello!")) - ) + def capture_and_return(*args: Any, **kwargs: Any) -> ChatCompletionResponse: + captured_messages.append(list(args[1])) + return _make_response("Hello!") mock_completion.side_effect = capture_and_return @@ -143,21 +136,21 @@ def test_generate_strips_response_content( expected: str, ) -> None: """Response content from the LLM is stripped of leading/trailing whitespace.""" - mock_completion.side_effect = lambda *args, **kwargs: StubResponse(StubMessage(content=raw_content)) + mock_completion.side_effect = lambda *args, **kwargs: _make_response(raw_content) result, _ = stub_model_facade.generate(prompt="test", parser=lambda x: x) assert result == expected -def test_model_alias_property(stub_model_facade, stub_model_configs): +def test_model_alias_property(stub_model_facade: ModelFacade, stub_model_configs: list[Any]) -> None: assert stub_model_facade.model_alias == stub_model_configs[0].alias -def test_usage_stats_property(stub_model_facade): +def test_usage_stats_property(stub_model_facade: ModelFacade) -> None: assert stub_model_facade.usage_stats is not None assert hasattr(stub_model_facade.usage_stats, "model_dump") -def test_consolidate_kwargs(stub_model_configs, stub_model_facade): +def test_consolidate_kwargs(stub_model_configs: list[Any], stub_model_facade: ModelFacade) -> None: # Model config generate kwargs are used as base, and purpose is removed result = stub_model_facade.consolidate_kwargs(purpose="test") assert result == 
stub_model_configs[0].inference_parameters.generate_kwargs @@ -191,126 +184,105 @@ def test_consolidate_kwargs(stub_model_configs, stub_model_facade): True, ], ) -@patch.object(CustomRouter, "completion", autospec=True) def test_completion_success( - mock_router_completion: Any, stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_model_client: MagicMock, skip_usage_tracking: bool, ) -> None: - mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response + expected_response = _make_response("Test response") + stub_model_client.completion.return_value = expected_response result = stub_model_facade.completion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking) - expected_messages = [message.to_dict() for message in stub_completion_messages] - assert result == stub_expected_completion_response - assert mock_router_completion.call_count == 1 - assert mock_router_completion.call_args[1] == { - "model": "stub-model-text", - "messages": expected_messages, - **stub_model_configs[0].inference_parameters.generate_kwargs, - } + assert result == expected_response + assert stub_model_client.completion.call_count == 1 -@patch.object(CustomRouter, "completion", autospec=True) def test_completion_with_exception( - mock_router_completion: Any, stub_completion_messages: list[ChatMessage], stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: - mock_router_completion.side_effect = Exception("Router error") + stub_model_client.completion.side_effect = Exception("Router error") with pytest.raises(Exception, match="Router error"): stub_model_facade.completion(stub_completion_messages) -@patch.object(CustomRouter, "completion", autospec=True) def test_completion_with_kwargs( - mock_router_completion: Any, stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_model_client: MagicMock, ) -> None: - captured_kwargs = {} - - def mock_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> ModelResponse: - captured_kwargs.update(kwargs) - return stub_expected_completion_response - - mock_router_completion.side_effect = mock_completion + expected_response = _make_response("Test response") + stub_model_client.completion.return_value = expected_response kwargs = {"temperature": 0.7, "max_tokens": 100} result = stub_model_facade.completion(stub_completion_messages, **kwargs) - assert result == stub_expected_completion_response - # completion kwargs overrides model config generate kwargs - assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} + assert result == expected_response + assert stub_model_client.completion.call_count == 1 -@patch.object(CustomRouter, "embedding", autospec=True) def test_generate_text_embeddings_success( - mock_router_embedding: Any, stub_model_facade: ModelFacade, - stub_expected_embedding_response: EmbeddingResponse, + stub_model_client: MagicMock, ) -> None: - mock_router_embedding.side_effect = lambda self, model, input, **kwargs: stub_expected_embedding_response + expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + stub_model_client.embeddings.return_value = EmbeddingResponse(vectors=expected_vectors) input_texts = ["test1", "test2"] result = stub_model_facade.generate_text_embeddings(input_texts) - assert result == 
[data["embedding"] for data in stub_expected_embedding_response.data] + assert result == expected_vectors -@patch.object(CustomRouter, "embedding", autospec=True) -def test_generate_text_embeddings_with_exception(mock_router_embedding: Any, stub_model_facade: ModelFacade) -> None: - mock_router_embedding.side_effect = Exception("Router error") +def test_generate_text_embeddings_with_exception( + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_client.embeddings.side_effect = Exception("Router error") with pytest.raises(Exception, match="Router error"): stub_model_facade.generate_text_embeddings(["test1", "test2"]) -@patch.object(CustomRouter, "embedding", autospec=True) def test_generate_text_embeddings_with_kwargs( - mock_router_embedding: Any, stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_embedding_response: EmbeddingResponse, + stub_model_client: MagicMock, ) -> None: - captured_kwargs = {} - - def mock_embedding(self: Any, model: str, input: list[str], **kwargs: Any) -> EmbeddingResponse: - captured_kwargs.update(kwargs) - return stub_expected_embedding_response - - mock_router_embedding.side_effect = mock_embedding + expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + stub_model_client.embeddings.return_value = EmbeddingResponse(vectors=expected_vectors) kwargs = {"temperature": 0.7, "max_tokens": 100, "input_type": "query"} _ = stub_model_facade.generate_text_embeddings(["test1", "test2"], **kwargs) - assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} + assert stub_model_client.embeddings.call_count == 1 def test_generate_with_mcp_tools( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: - tool_call = { - "id": "call-1", - "type": "function", - "function": {"name": "lookup", "arguments": '{"query": "foo"}'}, - } + tool_call = ToolCall(id="call-1", name="lookup", arguments_json='{"query": "foo"}') responses = [ - StubResponse(StubMessage(content=None, tool_calls=[tool_call])), - StubResponse(StubMessage(content="final result")), + _make_response(content=None, tool_calls=[tool_call]), + _make_response("final result"), ] captured_calls: list[tuple[list[ChatMessage], dict[str, Any]]] = [] registry_calls: list[tuple[str, str, dict[str, str], None]] = [] - def process_with_tracking(completion_response: Any) -> list[ChatMessage]: - message = completion_response.choices[0].message + def process_with_tracking(completion_response: ChatCompletionResponse) -> list[ChatMessage]: + message = completion_response.message if not message.tool_calls: return [ChatMessage.as_assistant(content=message.content or "")] registry_calls.append(("tools", "lookup", {"query": "foo"}, None)) + tc_dict = { + "id": "call-1", + "type": "function", + "function": {"name": "lookup", "arguments": '{"query": "foo"}'}, + } return [ - ChatMessage.as_assistant(content="", tool_calls=[tool_call]), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="tool-output", tool_call_id="call-1"), ] @@ -325,14 +297,14 @@ def process_with_tracking(completion_response: Any) -> list[ChatMessage]: ) registry = StubMCPRegistry(facade) - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: captured_calls.append((messages, kwargs)) return responses.pop(0) model = ModelFacade( 
model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -348,12 +320,12 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_with_tools_missing_registry( - stub_model_configs: Any, stub_secrets_resolver: Any, stub_model_provider_registry: Any + stub_model_configs: Any, stub_model_client: MagicMock, stub_model_provider_registry: Any ) -> None: model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=None, ) @@ -368,32 +340,32 @@ def test_generate_with_tools_missing_registry( def test_generate_with_tool_alias_multiple_turns( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Multiple tool call turns before final response.""" - tool_call_1 = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "foo"}'}} - tool_call_2 = {"id": "call-2", "type": "function", "function": {"name": "search", "arguments": '{"term": "bar"}'}} + tool_call_1 = ToolCall(id="call-1", name="lookup", arguments_json='{"query": "foo"}') + tool_call_2 = ToolCall(id="call-2", name="search", arguments_json='{"term": "bar"}') responses = [ - StubResponse(StubMessage(content="First lookup", tool_calls=[tool_call_1])), - StubResponse(StubMessage(content="Second search", tool_calls=[tool_call_2])), - StubResponse(StubMessage(content="final result after two tool turns")), + _make_response("First lookup", tool_calls=[tool_call_1]), + _make_response("Second search", tool_calls=[tool_call_2]), + _make_response("final result after two tool turns"), ] call_count = 0 facade = StubMCPFacade(max_tool_call_turns=5) registry = StubMCPRegistry(facade) - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal call_count call_count += 1 return responses.pop(0) model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -406,29 +378,29 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_with_tools_tracks_usage_stats( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Tool usage stats are properly tracked with generations_with_tools incremented.""" - tool_call_1 = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "foo"}'}} - tool_call_2 = {"id": "call-2", "type": "function", "function": {"name": "search", "arguments": '{"term": "bar"}'}} + tool_call_1 = ToolCall(id="call-1", name="lookup", arguments_json='{"query": "foo"}') + tool_call_2 = ToolCall(id="call-2", name="search", arguments_json='{"term": "bar"}') responses = [ - StubResponse(StubMessage(content="First lookup", tool_calls=[tool_call_1])), - StubResponse(StubMessage(content="Second search", tool_calls=[tool_call_2])), - StubResponse(StubMessage(content="final result")), + _make_response("First lookup", tool_calls=[tool_call_1]), + _make_response("Second search", tool_calls=[tool_call_2]), + 
_make_response("final result"), ] facade = StubMCPFacade(max_tool_call_turns=5) registry = StubMCPRegistry(facade) - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return responses.pop(0) model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -444,15 +416,15 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe assert result == "final result" # Verify tool usage stats are tracked correctly - assert model.usage_stats.tool_usage.total_tool_calls == 2 # 2 tool calls total - assert model.usage_stats.tool_usage.total_tool_call_turns == 2 # 2 turns with tool calls - assert model.usage_stats.tool_usage.total_generations == 1 # 1 generation - assert model.usage_stats.tool_usage.generations_with_tools == 1 # 1 generation with tools + assert model.usage_stats.tool_usage.total_tool_calls == 2 + assert model.usage_stats.tool_usage.total_tool_call_turns == 2 + assert model.usage_stats.tool_usage.total_generations == 1 + assert model.usage_stats.tool_usage.generations_with_tools == 1 def test_generate_with_tools_tracks_multiple_generations( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Tool usage is correctly tracked across multiple generations.""" @@ -461,35 +433,35 @@ def test_generate_with_tools_tracks_multiple_generations( model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) # Generation 1: 2 tool calls across 1 turn - tool_call_a = {"id": "call-a", "type": "function", "function": {"name": "lookup", "arguments": '{"q": "1"}'}} - tool_call_b = {"id": "call-b", "type": "function", "function": {"name": "lookup", "arguments": '{"q": "2"}'}} + tool_call_a = ToolCall(id="call-a", name="lookup", arguments_json='{"q": "1"}') + tool_call_b = ToolCall(id="call-b", name="lookup", arguments_json='{"q": "2"}') responses_gen1 = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call_a, tool_call_b])), - StubResponse(StubMessage(content="result 1")), + _make_response("", tool_calls=[tool_call_a, tool_call_b]), + _make_response("result 1"), ] - def _completion_gen1(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion_gen1(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return responses_gen1.pop(0) with patch.object(ModelFacade, "completion", new=_completion_gen1): model.generate(prompt="q1", parser=lambda x: x, tool_alias="tools") # Generation 2: 4 tool calls across 2 turns - tool_call_c = {"id": "call-c", "type": "function", "function": {"name": "search", "arguments": '{"q": "3"}'}} - tool_call_d = {"id": "call-d", "type": "function", "function": {"name": "search", "arguments": '{"q": "4"}'}} + tool_call_c = ToolCall(id="call-c", name="search", arguments_json='{"q": "3"}') + tool_call_d = ToolCall(id="call-d", name="search", arguments_json='{"q": "4"}') responses_gen2 = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call_a, tool_call_b])), - StubResponse(StubMessage(content="", tool_calls=[tool_call_c, tool_call_d])), - StubResponse(StubMessage(content="result 2")), + 
_make_response("", tool_calls=[tool_call_a, tool_call_b]), + _make_response("", tool_calls=[tool_call_c, tool_call_d]), + _make_response("result 2"), ] - def _completion_gen2(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion_gen2(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return responses_gen2.pop(0) with patch.object(ModelFacade, "completion", new=_completion_gen2): @@ -497,10 +469,10 @@ def _completion_gen2(self: Any, messages: list[ChatMessage], **kwargs: Any) -> S # Generation 3: No tool calls responses_gen3 = [ - StubResponse(StubMessage(content="result 3")), + _make_response("result 3"), ] - def _completion_gen3(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion_gen3(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return responses_gen3.pop(0) with patch.object(ModelFacade, "completion", new=_completion_gen3): @@ -515,37 +487,36 @@ def _completion_gen3(self: Any, messages: list[ChatMessage], **kwargs: Any) -> S def test_generate_tool_turn_limit_triggers_refusal( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """When max_tool_call_turns exceeded, refusal is used.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") - # Keep returning tool calls to exceed the limit responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Turn 1 - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Turn 2 (max) - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Turn 3 (exceeds, should refuse) - StubResponse(StubMessage(content="final answer after refusal")), + _make_response("", tool_calls=[tool_call]), # Turn 1 + _make_response("", tool_calls=[tool_call]), # Turn 2 (max) + _make_response("", tool_calls=[tool_call]), # Turn 3 (exceeds, should refuse) + _make_response("final answer after refusal"), ] process_calls = 0 refuse_calls = 0 - def custom_process_fn(completion_response: Any) -> list[ChatMessage]: + tc_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + + def custom_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: nonlocal process_calls process_calls += 1 - message = completion_response.choices[0].message return [ - ChatMessage.as_assistant(content="", tool_calls=message.tool_calls or []), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="tool-result", tool_call_id="call-1"), ] - def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: + def custom_refuse_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: nonlocal refuse_calls refuse_calls += 1 - message = completion_response.choices[0].message return [ - ChatMessage.as_assistant(content="", tool_calls=message.tool_calls or []), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="REFUSED: Budget exceeded", tool_call_id="call-1"), ] @@ -554,7 +525,7 @@ def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = 
responses[response_idx] response_idx += 1 @@ -562,8 +533,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -577,20 +548,21 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_turn_limit_model_responds_after_refusal( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Model provides final answer after refusal message.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") + tc_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Exceeds on first turn - StubResponse(StubMessage(content="I understand, here is the answer without tools")), + _make_response("", tool_calls=[tool_call]), # Exceeds on first turn + _make_response("I understand, here is the answer without tools"), ] - def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: + def custom_refuse_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: return [ - ChatMessage.as_assistant(content="", tool_calls=[tool_call]), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool( content="Tool call refused: You have reached the maximum number of tool-calling turns.", tool_call_id="call-1", @@ -606,7 +578,7 @@ def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -614,8 +586,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -629,20 +601,20 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_alias_not_in_registry( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Raises error when tool_alias not found in MCPRegistry.""" - class StubMCPRegistry: + class _StubMCPRegistry: def get_mcp(self, *, tool_alias: str) -> Any: raise ValueError(f"No tool config with alias {tool_alias!r} found!") model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, - mcp_registry=StubMCPRegistry(), + client=stub_model_client, + mcp_registry=_StubMCPRegistry(), ) with pytest.raises(MCPConfigurationError, match="not registered"): @@ -651,27 +623,27 @@ def get_mcp(self, *, tool_alias: str) -> Any: def test_generate_no_tool_alias_ignores_mcp( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """When tool_alias is None, no MCP operations occur.""" get_mcp_called 
= False - class StubMCPRegistry: + class _StubMCPRegistry: def get_mcp(self, *, tool_alias: str) -> Any: nonlocal get_mcp_called get_mcp_called = True raise RuntimeError("Should not be called") - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: assert "tools" not in kwargs # No tools should be passed - return StubResponse(StubMessage(content="response without tools")) + return _make_response("response without tools") model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, - mcp_registry=StubMCPRegistry(), + client=stub_model_client, + mcp_registry=_StubMCPRegistry(), ) with patch.object(ModelFacade, "completion", new=_completion): @@ -683,17 +655,17 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_calls_with_parser_corrections( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Tool calling works correctly with parser correction steps.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") parse_count = 0 responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Tool call - StubResponse(StubMessage(content="bad format")), # Parser will fail - StubResponse(StubMessage(content="correct format")), # Parser will succeed + _make_response("", tool_calls=[tool_call]), # Tool call + _make_response("bad format"), # Parser will fail + _make_response("correct format"), # Parser will succeed ] facade = StubMCPFacade() @@ -701,7 +673,7 @@ def test_generate_tool_calls_with_parser_corrections( response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -716,8 +688,8 @@ def _parser(text: str) -> str: model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -730,20 +702,18 @@ def _parser(text: str) -> str: def test_generate_tool_calls_with_conversation_restarts( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Tool calling works correctly with conversation restarts.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") messages_at_call: list[int] = [] - # First conversation: tool call + bad response - # After restart: tool call + good response responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), - StubResponse(StubMessage(content="still bad")), # Fails parser, triggers restart - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # After restart - StubResponse(StubMessage(content="good result")), + _make_response("", tool_calls=[tool_call]), + _make_response("still bad"), # Fails parser, triggers restart + _make_response("", tool_calls=[tool_call]), # After restart + 
_make_response("good result"), ] facade = StubMCPFacade() @@ -751,7 +721,7 @@ def test_generate_tool_calls_with_conversation_restarts( response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx messages_at_call.append(len(messages)) resp = responses[response_idx] @@ -765,8 +735,8 @@ def _parser(text: str) -> str: model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -777,7 +747,7 @@ def _parser(text: str) -> str: assert result == "good result" # After restart, message count should preserve tool call history (restart from checkpoint) - assert messages_at_call[2] == messages_at_call[1] # Both should be post-tool-call message count + assert messages_at_call[2] == messages_at_call[1] # ============================================================================= @@ -787,15 +757,15 @@ def _parser(text: str) -> str: def test_generate_trace_includes_tool_calls( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Returned trace includes tool call messages.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"q": "test"}'}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json='{"q": "test"}') responses = [ - StubResponse(StubMessage(content="Let me look that up", tool_calls=[tool_call])), - StubResponse(StubMessage(content="Here is the answer")), + _make_response("Let me look that up", tool_calls=[tool_call]), + _make_response("Here is the answer"), ] facade = StubMCPFacade() @@ -803,7 +773,7 @@ def test_generate_trace_includes_tool_calls( response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -811,8 +781,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -827,20 +797,21 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_trace_includes_tool_responses( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Returned trace includes tool response messages.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") + tc_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), - StubResponse(StubMessage(content="final")), + _make_response("", tool_calls=[tool_call]), + _make_response("final"), ] - def custom_process_fn(completion_response: Any) -> list[ChatMessage]: + def custom_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: return [ - ChatMessage.as_assistant(content="", 
tool_calls=[tool_call]), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="THE TOOL RESPONSE CONTENT", tool_call_id="call-1"), ] @@ -849,7 +820,7 @@ def custom_process_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -857,8 +828,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -873,20 +844,21 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_trace_includes_refusal_messages( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Returned trace includes refusal messages when budget exhausted.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") + tc_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Will be refused (max_turns=0) - StubResponse(StubMessage(content="answer without tools")), + _make_response("", tool_calls=[tool_call]), # Will be refused (max_turns=0) + _make_response("answer without tools"), ] - def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: + def custom_refuse_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: return [ - ChatMessage.as_assistant(content="", tool_calls=[tool_call]), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="BUDGET_EXCEEDED_REFUSAL", tool_call_id="call-1"), ] @@ -899,7 +871,7 @@ def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -907,8 +879,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -922,24 +894,22 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_trace_preserves_reasoning_content( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Trace messages preserve reasoning_content field.""" - response = StubResponse( - StubMessage( - content="The answer is 42", - reasoning_content="Let me think about this carefully...", - ) + response = _make_response( + "The answer is 42", + reasoning_content="Let me think about this carefully...", ) - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: 
list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return response model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, ) with patch.object(ModelFacade, "completion", new=_completion): @@ -958,15 +928,15 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_execution_error( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Handles MCP tool execution errors appropriately.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") - responses = [StubResponse(StubMessage(content="", tool_calls=[tool_call]))] + responses = [_make_response("", tool_calls=[tool_call])] - def error_process_fn(completion_response: Any) -> list[ChatMessage]: + def error_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: raise MCPToolError("Tool execution failed: Connection refused") facade = StubMCPFacade(process_fn=error_process_fn) @@ -974,7 +944,7 @@ def error_process_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -982,8 +952,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -994,16 +964,15 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_invalid_arguments( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Handles invalid tool arguments from LLM.""" - # Tool call with invalid JSON arguments - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "not valid json"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="not valid json") - responses = [StubResponse(StubMessage(content="", tool_calls=[tool_call]))] + responses = [_make_response("", tool_calls=[tool_call])] - def error_process_fn(completion_response: Any) -> list[ChatMessage]: + def error_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: raise MCPToolError("Invalid tool arguments for 'lookup': not valid json") facade = StubMCPFacade(process_fn=error_process_fn) @@ -1011,7 +980,7 @@ def error_process_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -1019,8 +988,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, 
mcp_registry=registry, ) @@ -1034,252 +1003,107 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe # ============================================================================= -@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) def test_generate_image_diffusion_tracks_image_usage( - mock_image_generation: Any, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image tracks image usage for diffusion models.""" - # Mock response with 3 images - mock_response = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image1_base64"), - lazy.litellm.types.utils.ImageObject(b64_json="image2_base64"), - lazy.litellm.types.utils.ImageObject(b64_json="image3_base64"), + stub_model_client.generate_image.return_value = ImageGenerationResponse( + images=[ + ImagePayload(b64_data="image1_base64"), + ImagePayload(b64_data="image2_base64"), + ImagePayload(b64_data="image3_base64"), ] ) - mock_image_generation.return_value = mock_response - # Verify initial state assert stub_model_facade.usage_stats.image_usage.total_images == 0 - # Generate images with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): images = stub_model_facade.generate_image(prompt="test prompt", n=3) - # Verify results assert len(images) == 3 assert images == ["image1_base64", "image2_base64", "image3_base64"] - - # Verify image usage was tracked assert stub_model_facade.usage_stats.image_usage.total_images == 3 assert stub_model_facade.usage_stats.image_usage.has_usage is True -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) def test_generate_image_chat_completion_tracks_image_usage( - mock_completion: Any, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image tracks image usage for chat completion models.""" - # Mock response with images attribute (Message requires type and index per ImageURLListItem) - mock_message = lazy.litellm.types.utils.Message( - role="assistant", - content="", + stub_model_client.generate_image.return_value = ImageGenerationResponse( images=[ - lazy.litellm.types.utils.ImageURLListItem( - type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0 - ), - lazy.litellm.types.utils.ImageURLListItem( - type="image_url", image_url={"url": "data:image/png;base64,image2"}, index=1 - ), - ], - ) - mock_response = lazy.litellm.types.utils.ModelResponse( - choices=[lazy.litellm.types.utils.Choices(message=mock_message)] + ImagePayload(b64_data="image1"), + ImagePayload(b64_data="image2"), + ] ) - mock_completion.return_value = mock_response - # Verify initial state assert stub_model_facade.usage_stats.image_usage.total_images == 0 - # Generate images with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): images = stub_model_facade.generate_image(prompt="test prompt") - # Verify results assert len(images) == 2 assert images == ["image1", "image2"] - - # Verify image usage was tracked assert stub_model_facade.usage_stats.image_usage.total_images == 2 assert stub_model_facade.usage_stats.image_usage.has_usage is True -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def test_generate_image_chat_completion_with_dict_format( - mock_completion: Any, - stub_model_facade: ModelFacade, -) -> None: - """Test that generate_image handles images as dicts with 
image_url string.""" - # Create mock message with images as dict with string image_url - mock_message = MagicMock() - mock_message.role = "assistant" - mock_message.content = "" - mock_message.images = [ - {"image_url": "data:image/png;base64,image1"}, - {"image_url": "data:image/jpeg;base64,image2"}, - ] - - mock_choice = MagicMock() - mock_choice.message = mock_message - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - - mock_completion.return_value = mock_response - - # Generate images - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): - images = stub_model_facade.generate_image(prompt="test prompt") - - # Verify results - assert len(images) == 2 - assert images == ["image1", "image2"] - - -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def test_generate_image_chat_completion_with_plain_strings( - mock_completion: Any, - stub_model_facade: ModelFacade, -) -> None: - """Test that generate_image handles images as plain strings.""" - # Create mock message with images as plain strings - mock_message = MagicMock() - mock_message.role = "assistant" - mock_message.content = "" - mock_message.images = [ - "data:image/png;base64,image1", - "image2", # Plain base64 without data URI prefix - ] - - mock_choice = MagicMock() - mock_choice.message = mock_message - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - - mock_completion.return_value = mock_response - - # Generate images - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): - images = stub_model_facade.generate_image(prompt="test prompt") - - # Verify results - assert len(images) == 2 - assert images == ["image1", "image2"] - - -@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) def test_generate_image_skip_usage_tracking( - mock_image_generation: Any, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image respects skip_usage_tracking flag.""" - mock_response = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image1_base64"), - lazy.litellm.types.utils.ImageObject(b64_json="image2_base64"), + stub_model_client.generate_image.return_value = ImageGenerationResponse( + images=[ + ImagePayload(b64_data="image1_base64"), + ImagePayload(b64_data="image2_base64"), ] ) - mock_image_generation.return_value = mock_response - # Verify initial state assert stub_model_facade.usage_stats.image_usage.total_images == 0 - # Generate images with skip_usage_tracking=True with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): images = stub_model_facade.generate_image(prompt="test prompt", skip_usage_tracking=True) - # Verify results assert len(images) == 2 - - # Verify image usage was NOT tracked assert stub_model_facade.usage_stats.image_usage.total_images == 0 assert stub_model_facade.usage_stats.image_usage.has_usage is False -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def test_generate_image_chat_completion_no_choices( - mock_completion: Any, - stub_model_facade: ModelFacade, -) -> None: - """Test that generate_image raises ImageGenerationError when response has no choices.""" - mock_response = lazy.litellm.types.utils.ModelResponse(choices=[]) - mock_completion.return_value = mock_response - - with patch("data_designer.engine.models.facade.is_image_diffusion_model", 
return_value=False): - with pytest.raises(ImageGenerationError, match="Image generation response missing choices"): - stub_model_facade.generate_image(prompt="test prompt") - - -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def test_generate_image_chat_completion_no_image_data( - mock_completion: Any, +def test_generate_image_no_image_data( stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image raises ImageGenerationError when no image data in response.""" - mock_message = lazy.litellm.types.utils.Message(role="assistant", content="just text, no image") - mock_response = lazy.litellm.types.utils.ModelResponse( - choices=[lazy.litellm.types.utils.Choices(message=mock_message)] - ) - mock_completion.return_value = mock_response + stub_model_client.generate_image.return_value = ImageGenerationResponse(images=[]) with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): - with pytest.raises(ImageGenerationError, match="No image data found in image generation response"): + with pytest.raises(ImageGenerationError, match="No image data found"): stub_model_facade.generate_image(prompt="test prompt") -@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) -def test_generate_image_diffusion_no_data( - mock_image_generation: Any, - stub_model_facade: ModelFacade, -) -> None: - """Test that generate_image raises ImageGenerationError when diffusion API returns no data.""" - mock_response = lazy.litellm.types.utils.ImageResponse(data=[]) - mock_image_generation.return_value = mock_response - - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): - with pytest.raises(ImageGenerationError, match="Image generation returned no data"): - stub_model_facade.generate_image(prompt="test prompt") - - -@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) def test_generate_image_accumulates_usage( - mock_image_generation: Any, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image accumulates image usage across multiple calls.""" - # First call - 2 images - mock_response1 = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image1"), - lazy.litellm.types.utils.ImageObject(b64_json="image2"), - ] + response1 = ImageGenerationResponse(images=[ImagePayload(b64_data="image1"), ImagePayload(b64_data="image2")]) + response2 = ImageGenerationResponse( + images=[ImagePayload(b64_data="image3"), ImagePayload(b64_data="image4"), ImagePayload(b64_data="image5")] ) - # Second call - 3 images - mock_response2 = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image3"), - lazy.litellm.types.utils.ImageObject(b64_json="image4"), - lazy.litellm.types.utils.ImageObject(b64_json="image5"), - ] - ) - mock_image_generation.side_effect = [mock_response1, mock_response2] + stub_model_client.generate_image.side_effect = [response1, response2] - # Verify initial state assert stub_model_facade.usage_stats.image_usage.total_images == 0 - # First generation with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): images1 = stub_model_facade.generate_image(prompt="test1") assert len(images1) == 2 assert stub_model_facade.usage_stats.image_usage.total_images == 2 - # Second generation images2 = 
stub_model_facade.generate_image(prompt="test2") assert len(images2) == 3 - # Usage should accumulate assert stub_model_facade.usage_stats.image_usage.total_images == 5 @@ -1295,52 +1119,43 @@ def test_generate_image_accumulates_usage( True, ], ) -@patch.object(CustomRouter, "acompletion", new_callable=AsyncMock) @pytest.mark.asyncio async def test_acompletion_success( - mock_router_acompletion: AsyncMock, stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_model_client: MagicMock, skip_usage_tracking: bool, ) -> None: - mock_router_acompletion.return_value = stub_expected_completion_response + expected_response = _make_response("Test response") + stub_model_client.acompletion = AsyncMock(return_value=expected_response) result = await stub_model_facade.acompletion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking) - expected_messages = [message.to_dict() for message in stub_completion_messages] - assert result == stub_expected_completion_response - assert mock_router_acompletion.call_count == 1 - assert mock_router_acompletion.call_args[1] == { - "model": "stub-model-text", - "messages": expected_messages, - **stub_model_configs[0].inference_parameters.generate_kwargs, - } + assert result == expected_response + assert stub_model_client.acompletion.call_count == 1 -@patch.object(CustomRouter, "acompletion", new_callable=AsyncMock) @pytest.mark.asyncio async def test_acompletion_with_exception( - mock_router_acompletion: AsyncMock, stub_completion_messages: list[ChatMessage], stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: - mock_router_acompletion.side_effect = Exception("Router error") + stub_model_client.acompletion = AsyncMock(side_effect=Exception("Router error")) with pytest.raises(Exception, match="Router error"): await stub_model_facade.acompletion(stub_completion_messages) -@patch.object(CustomRouter, "aembedding", new_callable=AsyncMock) @pytest.mark.asyncio async def test_agenerate_text_embeddings_success( - mock_router_aembedding: AsyncMock, stub_model_facade: ModelFacade, - stub_expected_embedding_response: EmbeddingResponse, + stub_model_client: MagicMock, ) -> None: - mock_router_aembedding.return_value = stub_expected_embedding_response + expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + stub_model_client.aembeddings = AsyncMock(return_value=EmbeddingResponse(vectors=expected_vectors)) input_texts = ["test1", "test2"] result = await stub_model_facade.agenerate_text_embeddings(input_texts) - assert result == [data["embedding"] for data in stub_expected_embedding_response.data] + assert result == expected_vectors @pytest.mark.parametrize( @@ -1363,7 +1178,7 @@ async def test_agenerate_correction_retries( max_conversation_restarts: int, total_calls: int, ) -> None: - bad_response = mock_oai_response_object("bad response") + bad_response = _make_response("bad response") mock_acompletion.return_value = bad_response def _failing_parser(response: str) -> str: @@ -1396,13 +1211,12 @@ async def test_agenerate_success( mock_acompletion: AsyncMock, stub_model_facade: ModelFacade, ) -> None: - good_response = mock_oai_response_object("parsed output") + good_response = _make_response("parsed output") mock_acompletion.return_value = good_response result, trace = await stub_model_facade.agenerate(prompt="test", parser=lambda x: x) assert result == "parsed output" assert mock_acompletion.call_count == 1 - # Trace should contain at least 
the user prompt and the assistant response assert any(msg.role == "user" for msg in trace) assert any(msg.role == "assistant" and msg.content == "parsed output" for msg in trace) @@ -1412,105 +1226,52 @@ async def test_agenerate_success( # ============================================================================= -@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock) @pytest.mark.asyncio async def test_agenerate_image_diffusion_success( - mock_aimage_generation: AsyncMock, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test async image generation via diffusion API.""" - mock_response = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image1_base64"), - lazy.litellm.types.utils.ImageObject(b64_json="image2_base64"), - ] + stub_model_client.agenerate_image = AsyncMock( + return_value=ImageGenerationResponse( + images=[ImagePayload(b64_data="image1_base64"), ImagePayload(b64_data="image2_base64")] + ) ) - mock_aimage_generation.return_value = mock_response with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): images = await stub_model_facade.agenerate_image(prompt="test prompt") assert len(images) == 2 assert images == ["image1_base64", "image2_base64"] - assert mock_aimage_generation.call_count == 1 - # Verify image usage was tracked assert stub_model_facade.usage_stats.image_usage.total_images == 2 -@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) @pytest.mark.asyncio async def test_agenerate_image_chat_completion_success( - mock_acompletion: AsyncMock, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test async image generation via chat completion API.""" - mock_message = lazy.litellm.types.utils.Message( - role="assistant", - content="", - images=[ - lazy.litellm.types.utils.ImageURLListItem( - type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0 - ), - ], + stub_model_client.agenerate_image = AsyncMock( + return_value=ImageGenerationResponse(images=[ImagePayload(b64_data="image1")]) ) - mock_response = lazy.litellm.types.utils.ModelResponse( - choices=[lazy.litellm.types.utils.Choices(message=mock_message)] - ) - mock_acompletion.return_value = mock_response with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): images = await stub_model_facade.agenerate_image(prompt="test prompt") assert len(images) == 1 assert images == ["image1"] - assert mock_acompletion.call_count == 1 assert stub_model_facade.usage_stats.image_usage.total_images == 1 -@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock) @pytest.mark.asyncio -async def test_agenerate_image_diffusion_no_data( - mock_aimage_generation: AsyncMock, +async def test_agenerate_image_no_data( stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: - """Test async image generation raises error when diffusion API returns no data.""" - mock_response = lazy.litellm.types.utils.ImageResponse(data=[]) - mock_aimage_generation.return_value = mock_response + """Test async image generation raises error when no data.""" + stub_model_client.agenerate_image = AsyncMock(return_value=ImageGenerationResponse(images=[])) with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): - with pytest.raises(ImageGenerationError, match="Image generation returned no data"): + with 
pytest.raises(ImageGenerationError, match="No image data found"): await stub_model_facade.agenerate_image(prompt="test prompt") - - -@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) -@pytest.mark.asyncio -async def test_agenerate_image_chat_completion_no_choices( - mock_acompletion: AsyncMock, - stub_model_facade: ModelFacade, -) -> None: - """Test async image generation raises error when response has no choices.""" - mock_response = lazy.litellm.types.utils.ModelResponse(choices=[]) - mock_acompletion.return_value = mock_response - - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): - with pytest.raises(ImageGenerationError, match="Image generation response missing choices"): - await stub_model_facade.agenerate_image(prompt="test prompt") - - -@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock) -@pytest.mark.asyncio -async def test_agenerate_image_skip_usage_tracking( - mock_aimage_generation: AsyncMock, - stub_model_facade: ModelFacade, -) -> None: - """Test that async image generation respects skip_usage_tracking flag.""" - mock_response = lazy.litellm.types.utils.ImageResponse( - data=[lazy.litellm.types.utils.ImageObject(b64_json="image1_base64")] - ) - mock_aimage_generation.return_value = mock_response - - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): - images = await stub_model_facade.agenerate_image(prompt="test prompt", skip_usage_tracking=True) - - assert len(images) == 1 - assert stub_model_facade.usage_stats.image_usage.total_images == 0 From b8579c269e75c9c4ede11615e97668cc9b44ff0b Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 5 Mar 2026 10:39:34 -0700 Subject: [PATCH 20/27] refactor: introduce TransportKwargs and rename MCP tool-call helpers --- .../src/data_designer/engine/mcp/facade.py | 24 +- .../engine/models/clients/__init__.py | 2 + .../models/clients/adapters/__init__.py | 2 + .../models/clients/adapters/litellm_bridge.py | 34 +- .../engine/models/clients/parsing.py | 15 - .../engine/models/clients/types.py | 56 ++- .../src/data_designer/engine/models/facade.py | 460 +++++++++--------- .../src/data_designer/engine/testing/stubs.py | 2 +- .../tests/engine/mcp/test_mcp_facade.py | 10 +- .../models/clients/test_litellm_bridge.py | 17 +- .../engine/models/clients/test_parsing.py | 161 ++++++ .../tests/engine/models/conftest.py | 2 + 12 files changed, 509 insertions(+), 276 deletions(-) create mode 100644 packages/data-designer-engine/tests/engine/models/clients/test_parsing.py diff --git a/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py b/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py index 89289ad5b..2adcb8d4b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py @@ -72,7 +72,7 @@ def timeout_sec(self) -> float | None: return self._tool_config.timeout_sec @staticmethod - def tool_call_count(completion_response: ChatCompletionResponse) -> int: + def get_tool_call_count(completion_response: ChatCompletionResponse) -> int: """Count the number of tool calls in a completion response.""" return len(completion_response.message.tool_calls) @@ -81,14 +81,6 @@ def has_tool_calls(completion_response: ChatCompletionResponse) -> bool: """Returns True if tool calls are present in the completion response.""" return len(completion_response.message.tool_calls) > 0 - def _resolve_provider(self, provider: MCPProviderT) -> 
MCPProviderT: - """Resolve secret references in an MCP provider's api_key.""" - api_key_ref = getattr(provider, "api_key", None) - if not api_key_ref: - return provider - resolved_key = self._secret_resolver.resolve(api_key_ref) - return provider.model_copy(update={"api_key": resolved_key}) - def get_tool_schemas(self) -> list[dict[str, Any]]: """Get OpenAI-compatible tool schemas for this configuration. @@ -182,7 +174,7 @@ def process_completion_response( ] # Has tool calls - execute and return all messages - tool_call_dicts = _canonical_tool_calls_to_dicts(tool_calls) + tool_call_dicts = _convert_canonical_tool_calls_to_dicts(tool_calls) assistant_message = self._build_assistant_tool_message(response_content, tool_call_dicts, reasoning_content) tool_messages = self._execute_tool_calls_from_canonical(tool_calls) @@ -227,7 +219,7 @@ def refuse_completion_response( ] # Build assistant message with tool calls (same as normal) - tool_call_dicts = _canonical_tool_calls_to_dicts(tool_calls) + tool_call_dicts = _convert_canonical_tool_calls_to_dicts(tool_calls) assistant_message = self._build_assistant_tool_message(response_content, tool_call_dicts, reasoning_content) # Build refusal messages instead of executing tools @@ -236,6 +228,14 @@ def refuse_completion_response( return [assistant_message, *tool_messages] + def _resolve_provider(self, provider: MCPProviderT) -> MCPProviderT: + """Resolve secret references in an MCP provider's api_key.""" + api_key_ref = getattr(provider, "api_key", None) + if not api_key_ref: + return provider + resolved_key = self._secret_resolver.resolve(api_key_ref) + return provider.model_copy(update={"api_key": resolved_key}) + def _build_assistant_tool_message( self, response: str | None, @@ -299,7 +299,7 @@ def _find_resolved_provider_for_tool(self, tool_name: str) -> MCPProviderT: raise MCPConfigurationError(f"Tool {tool_name!r} not found on any configured provider.") -def _canonical_tool_calls_to_dicts(tool_calls: list[ToolCall]) -> list[dict[str, Any]]: +def _convert_canonical_tool_calls_to_dicts(tool_calls: list[ToolCall]) -> list[dict[str, Any]]: """Convert canonical ToolCall objects to the internal dict format for ChatMessage.""" return [ { diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index 99312e4b8..fc72e1e4b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from data_designer.engine.models.clients.base import ModelClient from data_designer.engine.models.clients.errors import ( ProviderError, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py index 1b65e2dde..cc9feefea 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/__init__.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient, LiteLLMRouter __all__ = ["LiteLLMBridgeClient", "LiteLLMRouter"] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py index 45d615031..5623529a9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -17,7 +17,6 @@ from data_designer.engine.models.clients.parsing import ( aextract_images_from_chat_response, aextract_images_from_image_response, - collect_non_none_optional_fields, extract_embedding_vector, extract_images_from_chat_response, extract_images_from_image_response, @@ -31,6 +30,7 @@ EmbeddingResponse, ImageGenerationRequest, ImageGenerationResponse, + TransportKwargs, ) logger = logging.getLogger(__name__) @@ -74,58 +74,68 @@ def supports_image_generation(self) -> bool: return True def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + transport = TransportKwargs.from_request(request) with _handle_non_provider_errors(self.provider_name): response = self._router.completion( model=request.model, messages=request.messages, - **collect_non_none_optional_fields(request), + extra_headers=transport.headers or None, + **transport.body, ) return parse_chat_completion_response(response) async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + transport = TransportKwargs.from_request(request) with _handle_non_provider_errors(self.provider_name): response = await self._router.acompletion( model=request.model, messages=request.messages, - **collect_non_none_optional_fields(request), + extra_headers=transport.headers or None, + **transport.body, ) return parse_chat_completion_response(response) def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + transport = TransportKwargs.from_request(request) with _handle_non_provider_errors(self.provider_name): response = self._router.embedding( model=request.model, input=request.inputs, - **collect_non_none_optional_fields(request), + extra_headers=transport.headers or None, + **transport.body, ) vectors = [extract_embedding_vector(item) for item in getattr(response, "data", [])] return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response) async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + transport = TransportKwargs.from_request(request) with _handle_non_provider_errors(self.provider_name): response = await self._router.aembedding( model=request.model, input=request.inputs, - **collect_non_none_optional_fields(request), + extra_headers=transport.headers or None, + **transport.body, ) vectors = [extract_embedding_vector(item) for item in getattr(response, "data", [])] return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response) def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) + transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) with _handle_non_provider_errors(self.provider_name): if request.messages is not None: response = 
self._router.completion( model=request.model, messages=request.messages, - **image_kwargs, + extra_headers=transport.headers or None, + **transport.body, ) images = extract_images_from_chat_response(response) else: response = self._router.image_generation( prompt=request.prompt, model=request.model, - **image_kwargs, + extra_headers=transport.headers or None, + **transport.body, ) images = extract_images_from_image_response(response) @@ -133,20 +143,22 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp return ImageGenerationResponse(images=images, usage=usage, raw=response) async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE) + transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE) with _handle_non_provider_errors(self.provider_name): if request.messages is not None: response = await self._router.acompletion( model=request.model, messages=request.messages, - **image_kwargs, + extra_headers=transport.headers or None, + **transport.body, ) images = await aextract_images_from_chat_response(response) else: response = await self._router.aimage_generation( prompt=request.prompt, model=request.model, - **image_kwargs, + extra_headers=transport.headers or None, + **transport.body, ) images = await aextract_images_from_image_response(response) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index 315d5d4e7..c4aa5276f 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -5,7 +5,6 @@ from __future__ import annotations -import dataclasses import json import logging from typing import Any @@ -312,17 +311,3 @@ def get_first_value_or_none(values: Any) -> Any | None: if isinstance(values, list) and values: return values[0] return None - - -def collect_non_none_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: - """Extract non-None optional fields from a request dataclass, skipping *exclude*. - - The ``f.default is None`` check intentionally targets fields whose default is - ``None`` — i.e. truly optional kwargs the caller may or may not set. Fields with - non-``None`` defaults are not "optional" in this forwarding sense and are excluded. 
- """ - return { - f.name: v - for f in dataclasses.fields(request) - if f.name not in exclude and f.default is None and (v := getattr(request, f.name)) is not None - } diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index 3df379910..63766a738 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -3,8 +3,8 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any, Protocol +from dataclasses import dataclass, field, fields +from typing import Any, ClassVar, Protocol class HttpResponse(Protocol): @@ -101,3 +101,55 @@ class ImageGenerationResponse: images: list[ImagePayload] usage: Usage | None = None raw: Any | None = None + + +# --------------------------------------------------------------------------- +# Transport preparation +# --------------------------------------------------------------------------- + + +@dataclass +class TransportKwargs: + """Pre-processed kwargs ready for an HTTP client call. + + Adapters call ``TransportKwargs.from_request(request)`` instead of + manually handling ``extra_body`` / ``extra_headers`` on every request type. + + - ``body``: API-level keyword arguments with ``extra_body`` keys merged + into the top level (mirroring how LiteLLM flattens them). + - ``headers``: Extra HTTP headers to attach to the outgoing request. + """ + + _META_FIELDS: ClassVar[frozenset[str]] = frozenset({"extra_body", "extra_headers"}) + + body: dict[str, Any] + headers: dict[str, str] + + @classmethod + def from_request(cls, request: Any, *, exclude: frozenset[str] = frozenset()) -> TransportKwargs: + """Build transport-ready kwargs from a canonical request dataclass. + + 1. Collects all non-None optional fields (respecting *exclude*). + 2. Pops ``extra_body`` and merges its keys into the top-level body dict. + 3. Pops ``extra_headers`` into a separate headers dict. + """ + fields = cls._collect_optional_fields(request, exclude=exclude | cls._META_FIELDS) + + extra_body = getattr(request, "extra_body", None) or {} + extra_headers = getattr(request, "extra_headers", None) or {} + + return cls(body={**fields, **extra_body}, headers=dict(extra_headers)) + + @staticmethod + def _collect_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: + """Extract non-None optional fields from a request dataclass, skipping *exclude*. + + Targets fields whose default is ``None`` — i.e. truly optional kwargs + the caller may or may not set. Fields with non-``None`` defaults are + not "optional" in this forwarding sense and are excluded. 
+ """ + return { + f.name: v + for f in fields(request) + if f.name not in exclude and f.default is None and (v := getattr(request, f.name)) is not None + } diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 91545ecb2..d03a7a357 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -95,6 +95,18 @@ def max_parallel_requests(self) -> int: def usage_stats(self) -> ModelUsageStats: return self._usage_stats + def consolidate_kwargs(self, **kwargs: Any) -> dict[str, Any]: + # Remove purpose from kwargs to avoid passing it to the model + kwargs.pop("purpose", None) + kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} + if self.model_provider.extra_body: + kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + if self.model_provider.extra_headers: + kwargs["extra_headers"] = self.model_provider.extra_headers + return kwargs + + # --- completion / acompletion --- + def completion( self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any ) -> ChatCompletionResponse: @@ -122,15 +134,34 @@ def completion( if not skip_usage_tracking and response is not None: self._track_usage(response.usage, is_request_successful=True) - def consolidate_kwargs(self, **kwargs: Any) -> dict[str, Any]: - # Remove purpose from kwargs to avoid passing it to the model - kwargs.pop("purpose", None) - kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} - if self.model_provider.extra_body: - kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} - if self.model_provider.extra_headers: - kwargs["extra_headers"] = self.model_provider.extra_headers - return kwargs + async def acompletion( + self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + ) -> ChatCompletionResponse: + message_payloads = [message.to_dict() for message in messages] + logger.debug( + f"Prompting model {self.model_name!r}...", + extra={"model": self.model_name, "messages": message_payloads}, + ) + response = None + kwargs = self.consolidate_kwargs(**kwargs) + try: + request = self._build_chat_completion_request(message_payloads, kwargs) + response = await self._client.acompletion(request) + logger.debug( + f"Received completion from model {self.model_name!r}", + extra={ + "model": self.model_name, + "response": response, + "text": response.message.content, + "usage": self._usage_stats.model_dump(), + }, + ) + return response + finally: + if not skip_usage_tracking and response is not None: + self._track_usage(response.usage, is_request_successful=True) + + # --- generate / agenerate --- @catch_llm_exceptions def generate( @@ -227,7 +258,7 @@ def generate( # Process any tool calls in the response (handles parallel tool calling) if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): tool_call_turns += 1 - total_tool_calls += mcp_facade.tool_call_count(completion_response) + total_tool_calls += mcp_facade.get_tool_call_count(completion_response) if tool_call_turns > mcp_facade.max_tool_call_turns: # Gracefully refuse tool calls when budget is exhausted @@ -280,6 +311,105 @@ def generate( return output_obj, messages + @acatch_llm_exceptions + async def agenerate( + self, + prompt: str, + *, + parser: Callable[[str], Any] = _identity, + 
system_prompt: str | None = None, + multi_modal_context: list[dict[str, Any]] | None = None, + tool_alias: str | None = None, + max_correction_steps: int = 0, + max_conversation_restarts: int = 0, + skip_usage_tracking: bool = False, + purpose: str | None = None, + **kwargs: Any, + ) -> tuple[Any, list[ChatMessage]]: + output_obj = None + tool_schemas = None + tool_call_turns = 0 + total_tool_calls = 0 + curr_num_correction_steps = 0 + curr_num_restarts = 0 + + mcp_facade = self._get_mcp_facade(tool_alias) + + restart_checkpoint = prompt_to_messages( + user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context + ) + checkpoint_tool_call_turns = 0 + messages: list[ChatMessage] = deepcopy(restart_checkpoint) + + if mcp_facade is not None: + tool_schemas = await asyncio.to_thread(mcp_facade.get_tool_schemas) + + while True: + completion_kwargs = dict(kwargs) + if tool_schemas is not None: + completion_kwargs["tools"] = tool_schemas + + completion_response = await self.acompletion( + messages, + skip_usage_tracking=skip_usage_tracking, + **completion_kwargs, + ) + + if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): + tool_call_turns += 1 + total_tool_calls += mcp_facade.get_tool_call_count(completion_response) + + if tool_call_turns > mcp_facade.max_tool_call_turns: + messages.extend(mcp_facade.refuse_completion_response(completion_response)) + else: + messages.extend( + await asyncio.to_thread(mcp_facade.process_completion_response, completion_response) + ) + + restart_checkpoint = deepcopy(messages) + checkpoint_tool_call_turns = tool_call_turns + + continue + + response = (completion_response.message.content or "").strip() + reasoning_trace = completion_response.message.reasoning_content + messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) + curr_num_correction_steps += 1 + + try: + output_obj = parser(response) + break + except ParserException as exc: + if max_correction_steps == 0 and max_conversation_restarts == 0: + raise GenerationValidationFailureError( + "Unsuccessful generation attempt. No retries were attempted." + ) from exc + + if curr_num_correction_steps <= max_correction_steps: + messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc)))) + + elif curr_num_restarts < max_conversation_restarts: + curr_num_correction_steps = 0 + curr_num_restarts += 1 + messages = deepcopy(restart_checkpoint) + tool_call_turns = checkpoint_tool_call_turns + + else: + raise GenerationValidationFailureError( + f"Unsuccessful generation despite {max_correction_steps} correction steps " + f"and {max_conversation_restarts} conversation restarts." 
+ ) from exc + + if not skip_usage_tracking and mcp_facade is not None: + self._usage_stats.tool_usage.extend( + tool_calls=total_tool_calls, + tool_call_turns=tool_call_turns, + ) + + return output_obj, messages + + # --- generate_text_embeddings / agenerate_text_embeddings --- + @catch_llm_exceptions def generate_text_embeddings( self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any @@ -311,6 +441,39 @@ def generate_text_embeddings( if not skip_usage_tracking and response is not None: self._track_usage(response.usage, is_request_successful=True) + @acatch_llm_exceptions + async def agenerate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": len(input_texts), + }, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response: EmbeddingResponse | None = None + try: + request = self._build_embedding_request(input_texts, kwargs) + response = await self._client.aembeddings(request) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.vectors), + "usage": self._usage_stats.model_dump(), + }, + ) + if len(response.vectors) == len(input_texts): + return response.vectors + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.vectors)}") + finally: + if not skip_usage_tracking and response is not None: + self._track_usage(response.usage, is_request_successful=True) + + # --- generate_image / agenerate_image --- + @catch_llm_exceptions def generate_image( self, @@ -355,13 +518,74 @@ def generate_image( if not images: raise ImageGenerationError("No image data found in image generation response") - # Track image usage if not skip_usage_tracking and images: self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) self._track_usage(response.usage, is_request_successful=True) return images + @acatch_llm_exceptions + async def agenerate_image( + self, + prompt: str, + multi_modal_context: list[dict[str, Any]] | None = None, + skip_usage_tracking: bool = False, + **kwargs: Any, + ) -> list[str]: + """Async version of generate_image. Generate image(s) and return base64-encoded data. + + Automatically detects the appropriate API based on model name: + - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) -> image_generation API + - All other models -> chat/completions API (default) + + Both paths return base64-encoded image data. If the API returns multiple images, + all are returned in the list. + + Args: + prompt: The prompt for image generation + multi_modal_context: Optional list of image contexts for multi-modal generation. + Only used with autoregressive models via chat completions API. 
+ skip_usage_tracking: Whether to skip usage tracking + **kwargs: Additional arguments to pass to the model (including n=number of images) + + Returns: + List of base64-encoded image strings (without data URI prefix) + + Raises: + ImageGenerationError: If image generation fails or returns invalid data + """ + logger.debug( + f"Generating image with model {self.model_name!r}...", + extra={"model": self.model_name, "prompt": prompt}, + ) + + kwargs = self.consolidate_kwargs(**kwargs) + request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) + response = await self._client.agenerate_image(request) + + images = [img.b64_data for img in response.images] + + if not images: + raise ImageGenerationError("No image data found in image generation response") + + if not skip_usage_tracking and images: + self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) + self._track_usage(response.usage, is_request_successful=True) + + return images + + # --- close / aclose --- + + def close(self) -> None: + """Release resources held by the underlying client.""" + self._client.close() + + async def aclose(self) -> None: + """Async release resources held by the underlying client.""" + await self._client.aclose() + + # --- private helpers --- + def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: if tool_alias is None: return None @@ -450,217 +674,3 @@ def _track_usage(self, usage: Usage | None, *, is_request_successful: bool) -> N token_usage=token_usage, request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), ) - - async def acompletion( - self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any - ) -> ChatCompletionResponse: - message_payloads = [message.to_dict() for message in messages] - logger.debug( - f"Prompting model {self.model_name!r}...", - extra={"model": self.model_name, "messages": message_payloads}, - ) - response = None - kwargs = self.consolidate_kwargs(**kwargs) - try: - request = self._build_chat_completion_request(message_payloads, kwargs) - response = await self._client.acompletion(request) - logger.debug( - f"Received completion from model {self.model_name!r}", - extra={ - "model": self.model_name, - "response": response, - "text": response.message.content, - "usage": self._usage_stats.model_dump(), - }, - ) - return response - finally: - if not skip_usage_tracking and response is not None: - self._track_usage(response.usage, is_request_successful=True) - - @acatch_llm_exceptions - async def agenerate_text_embeddings( - self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any - ) -> list[list[float]]: - logger.debug( - f"Generating embeddings with model {self.model_name!r}...", - extra={ - "model": self.model_name, - "input_count": len(input_texts), - }, - ) - kwargs = self.consolidate_kwargs(**kwargs) - response: EmbeddingResponse | None = None - try: - request = self._build_embedding_request(input_texts, kwargs) - response = await self._client.aembeddings(request) - logger.debug( - f"Received embeddings from model {self.model_name!r}", - extra={ - "model": self.model_name, - "embedding_count": len(response.vectors), - "usage": self._usage_stats.model_dump(), - }, - ) - if len(response.vectors) == len(input_texts): - return response.vectors - raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.vectors)}") - finally: - if not skip_usage_tracking and response is not None: - self._track_usage(response.usage, 
is_request_successful=True) - - @acatch_llm_exceptions - async def agenerate( - self, - prompt: str, - *, - parser: Callable[[str], Any] = _identity, - system_prompt: str | None = None, - multi_modal_context: list[dict[str, Any]] | None = None, - tool_alias: str | None = None, - max_correction_steps: int = 0, - max_conversation_restarts: int = 0, - skip_usage_tracking: bool = False, - purpose: str | None = None, - **kwargs: Any, - ) -> tuple[Any, list[ChatMessage]]: - output_obj = None - tool_schemas = None - tool_call_turns = 0 - total_tool_calls = 0 - curr_num_correction_steps = 0 - curr_num_restarts = 0 - - mcp_facade = self._get_mcp_facade(tool_alias) - - restart_checkpoint = prompt_to_messages( - user_prompt=prompt, system_prompt=system_prompt, multi_modal_context=multi_modal_context - ) - checkpoint_tool_call_turns = 0 - messages: list[ChatMessage] = deepcopy(restart_checkpoint) - - if mcp_facade is not None: - tool_schemas = await asyncio.to_thread(mcp_facade.get_tool_schemas) - - while True: - completion_kwargs = dict(kwargs) - if tool_schemas is not None: - completion_kwargs["tools"] = tool_schemas - - completion_response = await self.acompletion( - messages, - skip_usage_tracking=skip_usage_tracking, - **completion_kwargs, - ) - - if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): - tool_call_turns += 1 - total_tool_calls += mcp_facade.tool_call_count(completion_response) - - if tool_call_turns > mcp_facade.max_tool_call_turns: - messages.extend(mcp_facade.refuse_completion_response(completion_response)) - else: - messages.extend( - await asyncio.to_thread(mcp_facade.process_completion_response, completion_response) - ) - - restart_checkpoint = deepcopy(messages) - checkpoint_tool_call_turns = tool_call_turns - - continue - - response = (completion_response.message.content or "").strip() - reasoning_trace = completion_response.message.reasoning_content - messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) - curr_num_correction_steps += 1 - - try: - output_obj = parser(response) - break - except ParserException as exc: - if max_correction_steps == 0 and max_conversation_restarts == 0: - raise GenerationValidationFailureError( - "Unsuccessful generation attempt. No retries were attempted." - ) from exc - - if curr_num_correction_steps <= max_correction_steps: - messages.append(ChatMessage.as_user(content=str(get_exception_primary_cause(exc)))) - - elif curr_num_restarts < max_conversation_restarts: - curr_num_correction_steps = 0 - curr_num_restarts += 1 - messages = deepcopy(restart_checkpoint) - tool_call_turns = checkpoint_tool_call_turns - - else: - raise GenerationValidationFailureError( - f"Unsuccessful generation despite {max_correction_steps} correction steps " - f"and {max_conversation_restarts} conversation restarts." - ) from exc - - if not skip_usage_tracking and mcp_facade is not None: - self._usage_stats.tool_usage.extend( - tool_calls=total_tool_calls, - tool_call_turns=tool_call_turns, - ) - - return output_obj, messages - - @acatch_llm_exceptions - async def agenerate_image( - self, - prompt: str, - multi_modal_context: list[dict[str, Any]] | None = None, - skip_usage_tracking: bool = False, - **kwargs: Any, - ) -> list[str]: - """Async version of generate_image. Generate image(s) and return base64-encoded data. - - Automatically detects the appropriate API based on model name: - - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) 
-> image_generation API - - All other models -> chat/completions API (default) - - Both paths return base64-encoded image data. If the API returns multiple images, - all are returned in the list. - - Args: - prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation. - Only used with autoregressive models via chat completions API. - skip_usage_tracking: Whether to skip usage tracking - **kwargs: Additional arguments to pass to the model (including n=number of images) - - Returns: - List of base64-encoded image strings (without data URI prefix) - - Raises: - ImageGenerationError: If image generation fails or returns invalid data - """ - logger.debug( - f"Generating image with model {self.model_name!r}...", - extra={"model": self.model_name, "prompt": prompt}, - ) - - kwargs = self.consolidate_kwargs(**kwargs) - request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) - response = await self._client.agenerate_image(request) - - images = [img.b64_data for img in response.images] - - if not images: - raise ImageGenerationError("No image data found in image generation response") - - # Track image usage - if not skip_usage_tracking and images: - self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) - self._track_usage(response.usage, is_request_successful=True) - - return images - - def close(self) -> None: - """Release resources held by the underlying client.""" - self._client.close() - - async def aclose(self) -> None: - """Async release resources held by the underlying client.""" - await self._client.aclose() diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py b/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py index 4807b64c8..ce4c2fb5a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py +++ b/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py @@ -229,7 +229,7 @@ def __init__( def get_tool_schemas(self) -> list[dict[str, Any]]: return self._tool_schemas - def tool_call_count(self, completion_response: ChatCompletionResponse) -> int: + def get_tool_call_count(self, completion_response: ChatCompletionResponse) -> int: return len(completion_response.message.tool_calls) def has_tool_calls(self, completion_response: ChatCompletionResponse) -> bool: diff --git a/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py b/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py index 5dae4fdbc..1300ed550 100644 --- a/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py +++ b/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py @@ -33,29 +33,29 @@ def _make_response( # ============================================================================= -# tool_call_count() tests +# get_tool_call_count() tests # ============================================================================= def test_tool_call_count_no_tools(mock_completion_response_no_tools: ChatCompletionResponse) -> None: """Returns 0 when response has no tool calls.""" - assert MCPFacade.tool_call_count(mock_completion_response_no_tools) == 0 + assert MCPFacade.get_tool_call_count(mock_completion_response_no_tools) == 0 def test_tool_call_count_single_tool(mock_completion_response_single_tool: ChatCompletionResponse) -> None: """Returns 1 for single tool call.""" - assert MCPFacade.tool_call_count(mock_completion_response_single_tool) == 1 + assert 
MCPFacade.get_tool_call_count(mock_completion_response_single_tool) == 1 def test_tool_call_count_parallel_tools(mock_completion_response_parallel_tools: ChatCompletionResponse) -> None: """Returns correct count for parallel tool calls (e.g., 3).""" - assert MCPFacade.tool_call_count(mock_completion_response_parallel_tools) == 3 + assert MCPFacade.get_tool_call_count(mock_completion_response_parallel_tools) == 3 def test_tool_call_count_none_tool_calls_attribute() -> None: """Returns 0 when tool_calls is empty.""" response = _make_response(content="Hello") - assert MCPFacade.tool_call_count(response) == 0 + assert MCPFacade.get_tool_call_count(response) == 0 # ============================================================================= diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py index 7c9b8db9a..f368536b1 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py @@ -58,13 +58,13 @@ def test_completion_maps_canonical_fields_from_litellm_response( mock_router.completion.assert_called_once_with( model="stub-model", messages=[{"role": "user", "content": "hello"}], + extra_headers={"x-trace": "1"}, tools=[{"type": "function", "function": {"name": "lookup"}}], temperature=0.2, top_p=0.8, max_tokens=256, - extra_body={"foo": "bar"}, - extra_headers={"x-trace": "1"}, metadata={"trace_id": "abc"}, + foo="bar", ) @@ -84,6 +84,7 @@ async def test_acompletion_maps_canonical_fields_from_litellm_response( mock_router.acompletion.assert_awaited_once_with( model="stub-model", messages=[{"role": "user", "content": "hello"}], + extra_headers=None, ) @@ -108,6 +109,7 @@ def test_embeddings_maps_vectors_and_usage( mock_router.embedding.assert_called_once_with( model="stub-model", input=["a", "b"], + extra_headers=None, encoding_format="float", dimensions=32, ) @@ -148,6 +150,7 @@ def test_generate_image_uses_chat_completion_path_when_messages_provided( mock_router.completion.assert_called_once_with( model="stub-model", messages=messages, + extra_headers=None, n=1, ) mock_router.image_generation.assert_not_called() @@ -176,7 +179,9 @@ def test_generate_image_uses_diffusion_path_without_messages( assert result.usage.output_tokens == 12 assert result.usage.total_tokens == 21 assert result.usage.generated_images == 2 - mock_router.image_generation.assert_called_once_with(prompt="make an image", model="stub-model", n=2) + mock_router.image_generation.assert_called_once_with( + prompt="make an image", model="stub-model", extra_headers=None, n=2 + ) @pytest.mark.asyncio @@ -197,7 +202,7 @@ async def test_aembeddings_maps_vectors_and_usage( assert result.usage is not None assert result.usage.input_tokens == 5 assert result.raw is response - mock_router.aembedding.assert_awaited_once_with(model="stub-model", input=["x", "y"]) + mock_router.aembedding.assert_awaited_once_with(model="stub-model", input=["x", "y"], extra_headers=None) def test_completion_coerces_list_content_blocks_to_string( @@ -245,7 +250,9 @@ async def test_agenerate_image_uses_diffusion_path_without_messages( assert result.images[0].b64_data == "YXN5bmM=" assert result.usage is not None assert result.usage.generated_images == 1 - mock_router.aimage_generation.assert_awaited_once_with(prompt="async image", model="stub-model", n=1) + mock_router.aimage_generation.assert_awaited_once_with( + 
prompt="async image", model="stub-model", extra_headers=None, n=1 + ) def test_completion_with_empty_choices_returns_empty_message( diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py new file mode 100644 index 000000000..3fa48348d --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py @@ -0,0 +1,161 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + EmbeddingRequest, + ImageGenerationRequest, + TransportKwargs, +) + +# --- TransportKwargs.from_request: extra_body flattening --- + + +def test_extra_body_keys_are_flattened_into_body() -> None: + request = ChatCompletionRequest( + model="m", + messages=[], + temperature=0.7, + extra_body={"reasoning_effort": "high", "seed": 42}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.body["temperature"] == 0.7 + assert transport.body["reasoning_effort"] == "high" + assert transport.body["seed"] == 42 + assert "extra_body" not in transport.body + + +def test_extra_body_none_produces_no_extra_keys() -> None: + request = ChatCompletionRequest(model="m", messages=[], temperature=0.5) + transport = TransportKwargs.from_request(request) + + assert transport.body == {"temperature": 0.5} + assert "extra_body" not in transport.body + + +def test_extra_body_empty_dict_produces_no_extra_keys() -> None: + request = ChatCompletionRequest(model="m", messages=[], extra_body={}) + transport = TransportKwargs.from_request(request) + + assert "extra_body" not in transport.body + + +# --- TransportKwargs.from_request: extra_headers separation --- + + +def test_extra_headers_are_separated_into_headers() -> None: + request = ChatCompletionRequest( + model="m", + messages=[], + extra_headers={"X-Custom": "value", "Authorization": "Bearer tok"}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.headers == {"X-Custom": "value", "Authorization": "Bearer tok"} + assert "extra_headers" not in transport.body + + +def test_extra_headers_none_produces_empty_headers() -> None: + request = ChatCompletionRequest(model="m", messages=[]) + transport = TransportKwargs.from_request(request) + + assert transport.headers == {} + + +# --- TransportKwargs.from_request: combined --- + + +def test_extra_body_and_headers_together() -> None: + request = ChatCompletionRequest( + model="m", + messages=[], + temperature=0.9, + max_tokens=100, + extra_body={"seed": 1}, + extra_headers={"X-Req-Id": "abc"}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.body == {"temperature": 0.9, "max_tokens": 100, "seed": 1} + assert transport.headers == {"X-Req-Id": "abc"} + + +# --- TransportKwargs.from_request: exclude parameter --- + + +def test_exclude_removes_fields_from_body() -> None: + request = ImageGenerationRequest( + model="m", + prompt="draw a cat", + messages=[{"role": "user", "content": "hi"}], + n=2, + extra_body={"quality": "hd"}, + ) + transport = TransportKwargs.from_request(request, exclude=frozenset({"messages", "prompt"})) + + assert "messages" not in transport.body + assert "prompt" not in transport.body + assert transport.body["n"] == 2 + assert transport.body["quality"] == "hd" + + +# --- TransportKwargs.from_request: works with 
all request types --- + + +def test_embedding_request() -> None: + request = EmbeddingRequest( + model="m", + inputs=["hello"], + extra_body={"input_type": "query"}, + extra_headers={"X-Api-Version": "2"}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.body["input_type"] == "query" + assert transport.headers == {"X-Api-Version": "2"} + assert "extra_body" not in transport.body + assert "extra_headers" not in transport.body + + +def test_image_generation_request() -> None: + request = ImageGenerationRequest( + model="m", + prompt="sunset", + n=3, + extra_body={"size": "1024x1024"}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.body["n"] == 3 + assert transport.body["size"] == "1024x1024" + assert transport.headers == {} + + +# --- TransportKwargs: falsy headers --- + + +def test_transport_kwargs_empty_headers_is_falsy() -> None: + tk = TransportKwargs(body={"a": 1}, headers={}) + assert not tk.headers + + +@pytest.mark.parametrize( + ("extra_body", "expected_body_keys"), + [ + (None, set()), + ({}, set()), + ({"a": 1}, {"a"}), + ({"a": 1, "b": 2}, {"a", "b"}), + ], +) +def test_extra_body_variations(extra_body: dict | None, expected_body_keys: set[str]) -> None: + request = ChatCompletionRequest(model="m", messages=[], extra_body=extra_body) + transport = TransportKwargs.from_request(request) + + assert expected_body_keys.issubset(transport.body.keys()) + assert "extra_body" not in transport.body diff --git a/packages/data-designer-engine/tests/engine/models/conftest.py b/packages/data-designer-engine/tests/engine/models/conftest.py index 965d815e4..3f065217b 100644 --- a/packages/data-designer-engine/tests/engine/models/conftest.py +++ b/packages/data-designer-engine/tests/engine/models/conftest.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from pathlib import Path from unittest.mock import MagicMock From 61024c00923e874d74013439ef02121223b95be3 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 5 Mar 2026 12:24:15 -0700 Subject: [PATCH 21/27] address feedback --- .../models/clients/adapters/litellm_bridge.py | 17 ++++++--- .../engine/models/clients/parsing.py | 35 +++++++++++++++---- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py index 45d615031..a7bc73b63 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -17,6 +17,7 @@ from data_designer.engine.models.clients.parsing import ( aextract_images_from_chat_response, aextract_images_from_image_response, + aparse_chat_completion_response, collect_non_none_optional_fields, extract_embedding_vector, extract_images_from_chat_response, @@ -89,7 +90,7 @@ async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionRes messages=request.messages, **collect_non_none_optional_fields(request), ) - return parse_chat_completion_response(response) + return await aparse_chat_completion_response(response) def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: with _handle_non_provider_errors(self.provider_name): @@ -120,14 +121,17 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp messages=request.messages, **image_kwargs, ) - images = extract_images_from_chat_response(response) else: response = self._router.image_generation( prompt=request.prompt, model=request.model, **image_kwargs, ) - images = extract_images_from_image_response(response) + + if request.messages is not None: + images = extract_images_from_chat_response(response) + else: + images = extract_images_from_image_response(response) usage = extract_usage(getattr(response, "usage", None), generated_images=len(images)) return ImageGenerationResponse(images=images, usage=usage, raw=response) @@ -141,14 +145,17 @@ async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerat messages=request.messages, **image_kwargs, ) - images = await aextract_images_from_chat_response(response) else: response = await self._router.aimage_generation( prompt=request.prompt, model=request.model, **image_kwargs, ) - images = await aextract_images_from_image_response(response) + + if request.messages is not None: + images = await aextract_images_from_chat_response(response) + else: + images = await aextract_images_from_image_response(response) usage = extract_usage(getattr(response, "usage", None), generated_images=len(images)) return ImageGenerationResponse(images=images, usage=usage, raw=response) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index 315d5d4e7..b565ee87b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -47,6 +47,21 @@ def parse_chat_completion_response(response: Any) -> ChatCompletionResponse: return 
ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) +async def aparse_chat_completion_response(response: Any) -> ChatCompletionResponse: + first_choice = get_first_value_or_none(getattr(response, "choices", None)) + message = get_value_from(first_choice, "message") + tool_calls = extract_tool_calls(get_value_from(message, "tool_calls")) + images = await aextract_images_from_chat_message(message) + assistant_message = AssistantMessage( + content=coerce_message_content(get_value_from(message, "content")), + reasoning_content=get_value_from(message, "reasoning_content"), + tool_calls=tool_calls, + images=images, + ) + usage = extract_usage(getattr(response, "usage", None), generated_images=len(images) if images else None) + return ChatCompletionResponse(message=assistant_message, usage=usage, raw=response) + + # --------------------------------------------------------------------------- # Image extraction # --------------------------------------------------------------------------- @@ -124,7 +139,7 @@ def parse_image_payload(raw_image: Any) -> ImagePayload | None: return ImagePayload(b64_data=load_image_url_to_base64(result), mime_type=None) return result except Exception: - logger.debug("Unable to parse image payload from response object.", exc_info=True) + logger.warning("Failed to parse image payload from response object; image dropped.", exc_info=True) return None @@ -135,7 +150,7 @@ async def aparse_image_payload(raw_image: Any) -> ImagePayload | None: return ImagePayload(b64_data=await aload_image_url_to_base64(result), mime_type=None) return result except Exception: - logger.debug("Unable to parse image payload from response object.", exc_info=True) + logger.warning("Failed to parse image payload from response object; image dropped.", exc_info=True) return None @@ -230,7 +245,11 @@ def extract_usage(raw_usage: Any, generated_images: int | None = None) -> Usage if output_tokens is None: output_tokens = get_value_from(raw_usage, "output_tokens") - if total_tokens is None and isinstance(input_tokens, int) and isinstance(output_tokens, int): + input_tokens = coerce_to_int_or_none(input_tokens) + output_tokens = coerce_to_int_or_none(output_tokens) + total_tokens = coerce_to_int_or_none(total_tokens) + + if total_tokens is None and input_tokens is not None and output_tokens is not None: total_tokens = input_tokens + output_tokens if generated_images is None: @@ -238,14 +257,16 @@ def extract_usage(raw_usage: Any, generated_images: int | None = None) -> Usage if generated_images is None and raw_usage is not None: generated_images = get_value_from(raw_usage, "images") + generated_images = coerce_to_int_or_none(generated_images) + if input_tokens is None and output_tokens is None and total_tokens is None and generated_images is None: return None return Usage( - input_tokens=coerce_to_int_or_none(input_tokens), - output_tokens=coerce_to_int_or_none(output_tokens), - total_tokens=coerce_to_int_or_none(total_tokens), - generated_images=coerce_to_int_or_none(generated_images), + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + generated_images=generated_images, ) From 49a45bad8eb9a0be7ff2f70c3c5da79301e621cd Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 5 Mar 2026 13:49:15 -0700 Subject: [PATCH 22/27] Address greptile comment in pr1 --- .../models/clients/adapters/litellm_bridge.py | 3 +- .../engine/models/clients/errors.py | 26 +++++++++++ .../models/clients/test_client_errors.py | 44 +++++++++++++++++++ 3 files changed, 
72 insertions(+), 1 deletion(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py index 69c43cf87..f5b861b4e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -12,6 +12,7 @@ from data_designer.engine.models.clients.errors import ( ProviderError, ProviderErrorKind, + extract_message_from_exception_string, map_http_status_to_provider_error_kind, ) from data_designer.engine.models.clients.parsing import ( @@ -195,7 +196,7 @@ def _handle_non_provider_errors(provider_name: str) -> Iterator[None]: raise ProviderError( kind=kind, - message=str(exc), + message=extract_message_from_exception_string(str(exc)), status_code=status_code if isinstance(status_code, int) else None, provider_name=provider_name, cause=exc, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index 8e8a3b0ac..359aaa51c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -5,6 +5,7 @@ import calendar import email.utils +import json import time from dataclasses import dataclass from enum import Enum @@ -118,6 +119,31 @@ def map_http_error_to_provider_error( ) +def extract_message_from_exception_string(raw: str) -> str: + """Extract a human-readable message from a stringified LiteLLM exception. + + LiteLLM often formats errors as ``"Error code: 400 - {json}"``. This + mirrors the structured-key lookup in ``_extract_structured_message`` but + operates on a raw string instead of an ``HttpResponse``. + """ + json_start = raw.find("{") + if json_start != -1: + try: + payload = json.loads(raw[json_start:]) + except (json.JSONDecodeError, ValueError): + return raw + if isinstance(payload, dict): + for key in ("message", "error", "detail"): + value = payload.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + if isinstance(value, dict): + nested = value.get("message") + if isinstance(nested, str) and nested.strip(): + return nested.strip() + return raw + + def _extract_response_text(response: HttpResponse) -> str: # Try structured JSON extraction first — most providers return structured error # bodies and we want the human-readable message, not raw JSON. 
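A minimal sketch of how the new helper behaves, for reviewer orientation (illustrative only; the parametrized cases in the test diff below are the authoritative contract):

    # Illustrative usage of the helper added above; the input mirrors
    # LiteLLM's "Error code: NNN - {json}" formatting.
    from data_designer.engine.models.clients.errors import extract_message_from_exception_string

    raw = 'Error code: 429 - {"error": {"message": "Rate limit reached", "type": "rate_limit_error"}}'
    assert extract_message_from_exception_string(raw) == "Rate limit reached"

    # Strings without an embedded JSON payload pass through unchanged.
    assert extract_message_from_exception_string("Connection timed out") == "Connection timed out"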
diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py index 194884f13..62828039d 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py @@ -10,6 +10,7 @@ from data_designer.engine.models.clients.errors import ( ProviderError, ProviderErrorKind, + extract_message_from_exception_string, map_http_error_to_provider_error, map_http_status_to_provider_error_kind, ) @@ -207,3 +208,46 @@ def test_map_http_error_retry_after_returns_none_for_garbage() -> None: ) error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") assert error.retry_after is None + + +@pytest.mark.parametrize( + "raw,expected", + [ + ( + "Error code: 400 - {'error': {'message': 'Context length exceeded', 'type': 'invalid_request_error'}}".replace( + "'", '"' + ), + "Context length exceeded", + ), + ( + 'Error code: 403 - {"error": "Insufficient permissions"}', + "Insufficient permissions", + ), + ( + 'Error code: 500 - {"message": "Internal failure"}', + "Internal failure", + ), + ( + 'Error code: 422 - {"detail": "Unprocessable entity"}', + "Unprocessable entity", + ), + ( + "Connection timed out", + "Connection timed out", + ), + ( + "Error code: 400 - {not valid json", + "Error code: 400 - {not valid json", + ), + ], + ids=[ + "nested-error-message", + "top-level-error-string", + "top-level-message-string", + "top-level-detail-string", + "no-json-passthrough", + "malformed-json-passthrough", + ], +) +def test_extract_message_from_exception_string(raw: str, expected: str) -> None: + assert extract_message_from_exception_string(raw) == expected From e8445cc783dd19f9a175d7877c53d63f4a177815 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Thu, 5 Mar 2026 16:45:57 -0700 Subject: [PATCH 23/27] refactor ProviderError from dataclass to regular Exception - Replace @dataclass + __post_init__ with explicit __init__ that calls super().__init__ properly, avoiding brittle field-ordering dependency - Store cause via __cause__ only, removing the redundant .cause attr - Update match pattern in handle_llm_exceptions for non-dataclass type - Rename shadowed local `fields` to `optional_fields` in TransportKwargs --- .../engine/models/clients/errors.py | 33 +++++++++++-------- .../engine/models/clients/types.py | 4 +-- .../src/data_designer/engine/models/errors.py | 10 ++++-- .../models/clients/test_litellm_bridge.py | 2 +- 4 files changed, 30 insertions(+), 19 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py index 359aaa51c..da3c19383 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py @@ -7,7 +7,6 @@ import email.utils import json import time -from dataclasses import dataclass from enum import Enum from data_designer.engine.models.clients.types import HttpResponse @@ -29,20 +28,26 @@ class ProviderErrorKind(str, Enum): UNSUPPORTED_CAPABILITY = "unsupported_capability" -@dataclass class ProviderError(Exception): - kind: ProviderErrorKind - message: str - status_code: int | None = None - provider_name: str | None = None - model_name: str | None = None - retry_after: float | None = None - cause: Exception | 
None = None - - def __post_init__(self) -> None: - Exception.__init__(self, self.message) - if self.cause is not None: - self.__cause__ = self.cause + def __init__( + self, + kind: ProviderErrorKind, + message: str, + status_code: int | None = None, + provider_name: str | None = None, + model_name: str | None = None, + retry_after: float | None = None, + cause: Exception | None = None, + ) -> None: + super().__init__(message) + self.kind = kind + self.message = message + self.status_code = status_code + self.provider_name = provider_name + self.model_name = model_name + self.retry_after = retry_after + if cause is not None: + self.__cause__ = cause def __str__(self) -> str: return self.message diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index 63766a738..f4f345b36 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -133,12 +133,12 @@ def from_request(cls, request: Any, *, exclude: frozenset[str] = frozenset()) -> 2. Pops ``extra_body`` and merges its keys into the top-level body dict. 3. Pops ``extra_headers`` into a separate headers dict. """ - fields = cls._collect_optional_fields(request, exclude=exclude | cls._META_FIELDS) + optional_fields = cls._collect_optional_fields(request, exclude=exclude | cls._META_FIELDS) extra_body = getattr(request, "extra_body", None) or {} extra_headers = getattr(request, "extra_headers", None) or {} - return cls(body={**fields, **extra_body}, headers=dict(extra_headers)) + return cls(body={**optional_fields, **extra_body}, headers=dict(extra_headers)) @staticmethod def _collect_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index 7a98f40d7..3c56f954b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -125,9 +125,15 @@ def handle_llm_exceptions( err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose) match exception: # Canonical ProviderError from the client adapter layer - case ProviderError(kind=kind): + case ProviderError(): _raise_from_provider_error( - exception, kind, model_name, model_provider_name, purpose, authentication_error, err_msg_parser + exception, + exception.kind, + model_name, + model_provider_name, + purpose, + authentication_error, + err_msg_parser, ) # LiteLLM-specific errors (safety net during bridge period) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py index f368536b1..c95e7c070 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py @@ -307,7 +307,7 @@ def test_completion_wraps_router_exception_with_status_code( assert exc_info.value.kind == ProviderErrorKind.RATE_LIMIT assert exc_info.value.status_code == 429 assert exc_info.value.provider_name == "stub-provider" - assert exc_info.value.cause is exc + assert exc_info.value.__cause__ is exc def 
test_completion_wraps_generic_router_exception( From 521c1e4bd550858e4cf8e0830dc709303ba7e650 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 6 Mar 2026 09:39:27 -0700 Subject: [PATCH 24/27] Address greptile feedback --- .../src/data_designer/engine/mcp/facade.py | 3 +- .../engine/models/clients/types.py | 6 +++ .../src/data_designer/engine/models/facade.py | 49 +++++++++++++++---- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py b/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py index 2adcb8d4b..10eecffc4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py @@ -270,7 +270,8 @@ def _execute_tool_calls_from_canonical( providers_str = ", ".join(repr(p) for p in self._tool_config.providers) raise MCPToolError(f"Tool {tc.name!r} is not permitted for providers: {providers_str}.") - arguments = json.loads(tc.arguments_json) if tc.arguments_json else {} + arguments_raw = json.loads(tc.arguments_json) if tc.arguments_json else {} + arguments = arguments_raw if isinstance(arguments_raw, dict) else {} resolved_provider = self._find_resolved_provider_for_tool(tc.name) calls_to_execute.append((resolved_provider, tc.name, arguments, tc.id)) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index f4f345b36..d83c7a121 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -54,6 +54,12 @@ class ChatCompletionRequest: temperature: float | None = None top_p: float | None = None max_tokens: int | None = None + stop: str | list[str] | None = None + seed: int | None = None + response_format: dict[str, Any] | None = None + frequency_penalty: float | None = None + presence_penalty: float | None = None + n: int | None = None timeout: float | None = None extra_body: dict[str, Any] | None = None extra_headers: dict[str, str] | None = None diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index d03a7a357..0e55f018c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -48,7 +48,21 @@ def _identity(x: Any) -> Any: # Known keyword arguments extracted into ChatCompletionRequest fields. 
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py
index d03a7a357..0e55f018c 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py
@@ -48,7 +48,21 @@ def _identity(x: Any) -> Any:
 
 # Known keyword arguments extracted into ChatCompletionRequest fields.
 _COMPLETION_REQUEST_FIELDS = frozenset(
-    {"temperature", "top_p", "max_tokens", "timeout", "tools", "extra_body", "extra_headers"}
+    {
+        "temperature",
+        "top_p",
+        "max_tokens",
+        "stop",
+        "seed",
+        "response_format",
+        "frequency_penalty",
+        "presence_penalty",
+        "n",
+        "timeout",
+        "tools",
+        "extra_body",
+        "extra_headers",
+    }
 )
 
 
@@ -131,8 +145,11 @@ def completion(
             )
             return response
         finally:
-            if not skip_usage_tracking and response is not None:
-                self._track_usage(response.usage, is_request_successful=True)
+            if not skip_usage_tracking:
+                self._track_usage(
+                    response.usage if response is not None else None,
+                    is_request_successful=response is not None,
+                )
 
     async def acompletion(
         self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any
@@ -158,8 +175,11 @@
             )
             return response
         finally:
-            if not skip_usage_tracking and response is not None:
-                self._track_usage(response.usage, is_request_successful=True)
+            if not skip_usage_tracking:
+                self._track_usage(
+                    response.usage if response is not None else None,
+                    is_request_successful=response is not None,
+                )
 
     # --- generate / agenerate ---
 
@@ -438,8 +458,11 @@ def generate_text_embeddings(
                 return response.vectors
             raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.vectors)}")
         finally:
-            if not skip_usage_tracking and response is not None:
-                self._track_usage(response.usage, is_request_successful=True)
+            if not skip_usage_tracking:
+                self._track_usage(
+                    response.usage if response is not None else None,
+                    is_request_successful=response is not None,
+                )
 
     @acatch_llm_exceptions
     async def agenerate_text_embeddings(
@@ -469,8 +492,11 @@
                 return response.vectors
             raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.vectors)}")
         finally:
-            if not skip_usage_tracking and response is not None:
-                self._track_usage(response.usage, is_request_successful=True)
+            if not skip_usage_tracking:
+                self._track_usage(
+                    response.usage if response is not None else None,
+                    is_request_successful=response is not None,
+                )
 
     # --- generate_image / agenerate_image ---
 
@@ -611,6 +637,11 @@ def _build_chat_completion_request(
             metadata[key] = value
 
         if metadata:
+            logger.debug(
+                "Unknown kwargs %s routed to LiteLLM metadata (not forwarded as model parameters). "
+                "Use 'extra_body' to pass non-standard parameters to the model.",
+                sorted(metadata.keys()),
+            )
             request_fields["metadata"] = metadata
 
         return ChatCompletionRequest(**request_fields)
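The recurring `try`/`finally` change above is the behavioral heart of this patch: failed requests are now recorded as unsuccessful instead of being silently skipped. A small self-contained sketch of the pattern (a stand-in facade, not the real one; `_track_usage` semantics are assumed from the diff):

```python
# Sketch: tracking in `finally` runs on success and on exceptions alike.
# On failure, `response` stays None, so the call is counted as unsuccessful.
class FakeFacade:
    def __init__(self, client):
        self._client = client
        self.successes = 0
        self.failures = 0

    def _track_usage(self, usage, *, is_request_successful: bool) -> None:
        if is_request_successful:
            self.successes += 1
        else:
            self.failures += 1

    def completion(self):
        response = None
        try:
            response = self._client()  # may raise
            return response
        finally:
            self._track_usage(
                getattr(response, "usage", None),
                is_request_successful=response is not None,
            )


def failing_client():
    raise RuntimeError("boom")


facade = FakeFacade(failing_client)
try:
    facade.completion()
except RuntimeError:
    pass
assert (facade.successes, facade.failures) == (0, 1)
```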
" + "Use 'extra_body' to pass non-standard parameters to the model.", + sorted(metadata.keys()), + ) request_fields["metadata"] = metadata return ChatCompletionRequest(**request_fields) From 4836e0394104c9468c5142b47da185abe4c7ffcd Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 6 Mar 2026 10:04:48 -0700 Subject: [PATCH 25/27] PR feedback --- .../src/data_designer/engine/models/errors.py | 2 -- .../src/data_designer/engine/models/facade.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index 3c56f954b..88280eeb2 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -133,7 +133,6 @@ def handle_llm_exceptions( model_provider_name, purpose, authentication_error, - err_msg_parser, ) # LiteLLM-specific errors (safety net during bridge period) @@ -347,7 +346,6 @@ def _raise_from_provider_error( model_provider_name: str, purpose: str, authentication_error: FormattedLLMErrorMessage, - err_msg_parser: DownstreamLLMExceptionMessageParser, ) -> None: """Map a canonical ProviderError to the appropriate DataDesignerError subclass.""" _KIND_MAP: dict[ProviderErrorKind, type[DataDesignerError]] = { diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 0e55f018c..28b045886 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -651,6 +651,8 @@ def _build_embedding_request(self, input_texts: list[str], kwargs: dict[str, Any return EmbeddingRequest( model=self.model_name, inputs=input_texts, + encoding_format=kwargs.get("encoding_format"), + dimensions=kwargs.get("dimensions"), timeout=kwargs.get("timeout"), extra_body=kwargs.get("extra_body"), extra_headers=kwargs.get("extra_headers"), From ae1bf9845b76b924164ca3690744d7cfdab59bcc Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 6 Mar 2026 10:48:46 -0700 Subject: [PATCH 26/27] track usage tracking in finally block for images --- .../src/data_designer/engine/models/facade.py | 51 ++++++++++++------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 28b045886..dd9b37e89 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -19,6 +19,7 @@ EmbeddingRequest, EmbeddingResponse, ImageGenerationRequest, + ImageGenerationResponse, Usage, ) from data_designer.engine.models.errors import ( @@ -536,19 +537,26 @@ def generate_image( ) kwargs = self.consolidate_kwargs(**kwargs) - request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) - response = self._client.generate_image(request) + response: ImageGenerationResponse | None = None + try: + request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) + response = self._client.generate_image(request) - images = [img.b64_data for img in response.images] + images = [img.b64_data for img in response.images] - if not images: - raise ImageGenerationError("No image data found in image generation 
response") + if not images: + raise ImageGenerationError("No image data found in image generation response") - if not skip_usage_tracking and images: - self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) - self._track_usage(response.usage, is_request_successful=True) + if not skip_usage_tracking: + self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) - return images + return images + finally: + if not skip_usage_tracking: + self._track_usage( + response.usage if response is not None else None, + is_request_successful=response is not None, + ) @acatch_llm_exceptions async def agenerate_image( @@ -586,19 +594,26 @@ async def agenerate_image( ) kwargs = self.consolidate_kwargs(**kwargs) - request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) - response = await self._client.agenerate_image(request) + response: ImageGenerationResponse | None = None + try: + request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) + response = await self._client.agenerate_image(request) - images = [img.b64_data for img in response.images] + images = [img.b64_data for img in response.images] - if not images: - raise ImageGenerationError("No image data found in image generation response") + if not images: + raise ImageGenerationError("No image data found in image generation response") - if not skip_usage_tracking and images: - self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) - self._track_usage(response.usage, is_request_successful=True) + if not skip_usage_tracking: + self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) - return images + return images + finally: + if not skip_usage_tracking: + self._track_usage( + response.usage if response is not None else None, + is_request_successful=response is not None, + ) # --- close / aclose --- From 18b9966ba5c66cd8acd9a86f6a7708f20c432ce1 Mon Sep 17 00:00:00 2001 From: Nabin Mulepati Date: Fri, 6 Mar 2026 11:27:32 -0700 Subject: [PATCH 27/27] pr feedback --- .../engine/models/clients/parsing.py | 3 +- .../src/data_designer/engine/models/errors.py | 4 +- .../engine/models/clients/test_parsing.py | 54 +++++++++++++++++++ 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py index 675646c2a..e5d74d440 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py @@ -7,6 +7,7 @@ import json import logging +import uuid from typing import Any from data_designer.config.utils.image_helpers import ( @@ -205,7 +206,7 @@ def extract_tool_calls(raw_tool_calls: Any) -> list[ToolCall]: normalized_tool_calls: list[ToolCall] = [] for raw_tool_call in raw_tool_calls: - tool_call_id = get_value_from(raw_tool_call, "id") or "" + tool_call_id = get_value_from(raw_tool_call, "id") or uuid.uuid4().hex function = get_value_from(raw_tool_call, "function") name = get_value_from(function, "name") or "" arguments_value = get_value_from(function, "arguments") diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index 88280eeb2..6ad084fa7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ 
From 18b9966ba5c66cd8acd9a86f6a7708f20c432ce1 Mon Sep 17 00:00:00 2001
From: Nabin Mulepati
Date: Fri, 6 Mar 2026 11:27:32 -0700
Subject: [PATCH 27/27] PR feedback

---
 .../engine/models/clients/parsing.py          |  3 +-
 .../src/data_designer/engine/models/errors.py |  4 +-
 .../engine/models/clients/test_parsing.py     | 54 +++++++++++++++++++
 3 files changed, 58 insertions(+), 3 deletions(-)

diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py
index 675646c2a..e5d74d440 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py
@@ -7,6 +7,7 @@
 
 import json
 import logging
+import uuid
 from typing import Any
 
 from data_designer.config.utils.image_helpers import (
@@ -205,7 +206,7 @@ def extract_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
 
     normalized_tool_calls: list[ToolCall] = []
     for raw_tool_call in raw_tool_calls:
-        tool_call_id = get_value_from(raw_tool_call, "id") or ""
+        tool_call_id = get_value_from(raw_tool_call, "id") or uuid.uuid4().hex
         function = get_value_from(raw_tool_call, "function")
         name = get_value_from(function, "name") or ""
         arguments_value = get_value_from(function, "arguments")
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py
index 88280eeb2..6ad084fa7 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py
@@ -6,7 +6,7 @@
 import logging
 from collections.abc import Callable
 from functools import wraps
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, NoReturn
 
 from pydantic import BaseModel
 
@@ -346,7 +346,7 @@ def _raise_from_provider_error(
     model_provider_name: str,
     purpose: str,
     authentication_error: FormattedLLMErrorMessage,
-) -> None:
+) -> NoReturn:
     """Map a canonical ProviderError to the appropriate DataDesignerError subclass."""
     _KIND_MAP: dict[ProviderErrorKind, type[DataDesignerError]] = {
         ProviderErrorKind.RATE_LIMIT: ModelRateLimitError,
diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py
index 3fa48348d..d48502431 100644
--- a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py
+++ b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py
@@ -5,6 +5,7 @@
 
 import pytest
 
+from data_designer.engine.models.clients.parsing import extract_tool_calls
 from data_designer.engine.models.clients.types import (
     ChatCompletionRequest,
     EmbeddingRequest,
@@ -159,3 +160,56 @@ def test_extra_body_variations(extra_body: dict | None, expected_body_keys: set[
 
     assert expected_body_keys.issubset(transport.body.keys())
     assert "extra_body" not in transport.body
+
+
+# --- extract_tool_calls ---
+
+
+def _make_raw_tool_call(
+    tool_id: str | None = "call-1",
+    name: str = "lookup",
+    arguments: str = '{"q": "test"}',
+) -> dict:
+    tc: dict = {"type": "function", "function": {"name": name, "arguments": arguments}}
+    if tool_id is not None:
+        tc["id"] = tool_id
+    return tc
+
+
+def test_extract_tool_calls_basic() -> None:
+    raw = [_make_raw_tool_call()]
+    result = extract_tool_calls(raw)
+
+    assert len(result) == 1
+    assert result[0].id == "call-1"
+    assert result[0].name == "lookup"
+    assert result[0].arguments_json == '{"q": "test"}'
+
+
+@pytest.mark.parametrize("tool_id", [None, ""], ids=["missing_id", "empty_string_id"])
+def test_extract_tool_calls_falsy_id_generates_uuid(tool_id: str | None) -> None:
+    raw = [_make_raw_tool_call(tool_id=tool_id)]
+    result = extract_tool_calls(raw)
+
+    assert len(result) == 1
+    assert len(result[0].id) == 32  # uuid4().hex length
+    assert result[0].id.isalnum()
+
+
+def test_extract_tool_calls_multiple_missing_ids_are_unique() -> None:
+    raw = [_make_raw_tool_call(tool_id=None), _make_raw_tool_call(tool_id=None)]
+    result = extract_tool_calls(raw)
+
+    assert result[0].id != result[1].id
+
+
+@pytest.mark.parametrize("raw_input", [None, []], ids=["none", "empty_list"])
+def test_extract_tool_calls_empty_input(raw_input: list | None) -> None:
+    assert extract_tool_calls(raw_input) == []
+
+
+def test_extract_tool_calls_none_arguments() -> None:
+    raw = [{"id": "call-1", "function": {"name": "lookup", "arguments": None}}]
+    result = extract_tool_calls(raw)
+
+    assert result[0].arguments_json == "{}"
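Synthesizing ids with `uuid.uuid4().hex` matters because OpenAI-style protocols join tool results back to their originating calls by `tool_call_id`; empty ids from a lax provider would collide across parallel calls. A self-contained sketch of the normalization (the `ToolCall` below is a simplified stand-in for the project's type):

```python
# Sketch: falsy provider ids are replaced with unique 32-char hex strings,
# matching the assertions in the new tests above.
import uuid
from dataclasses import dataclass


@dataclass
class ToolCall:
    id: str
    name: str


def normalize(raw: list[dict]) -> list[ToolCall]:
    return [
        ToolCall(id=tc.get("id") or uuid.uuid4().hex, name=tc["function"]["name"])
        for tc in raw
    ]


calls = normalize([{"function": {"name": "lookup"}}, {"function": {"name": "lookup"}}])
assert calls[0].id != calls[1].id              # unique even when the provider sends none
assert len(calls[0].id) == 32 and calls[0].id.isalnum()
```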