Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
ab30a2d
plans for model facade overhaul
nabinchha Feb 19, 2026
43824ea
update plan
nabinchha Feb 20, 2026
2a5f1e4
add review
johnnygreco Feb 20, 2026
f945d5b
address feedback + add more details after several self reviews
nabinchha Feb 20, 2026
dfa3817
update plan doc
nabinchha Feb 25, 2026
5b18f74
Merge branch 'main' into nm/overhaul-model-facade-guts
nabinchha Feb 25, 2026
0f449a7
address nits
nabinchha Feb 25, 2026
37f092a
Merge branch 'nm/overhaul-model-facade-guts' into nm/overhaul-model-f…
nabinchha Feb 25, 2026
08e57f8
Add canonical objects
nabinchha Feb 26, 2026
3ab18ee
Merge branch 'main' into nm/overhaul-model-facade-guts-pr1
nabinchha Feb 27, 2026
34349c7
self-review feedback + address
nabinchha Feb 28, 2026
6aae4b6
add LiteLLMRouter protocol to strongly type bridge router param
nabinchha Feb 28, 2026
2a53d37
simplify some things
nabinchha Feb 28, 2026
4e2f3af
add a protocol for http response like object
nabinchha Feb 28, 2026
b1c85f2
move HttpResponse
nabinchha Feb 28, 2026
f6dc769
update PR-1 architecture notes for lifecycle and router protocol
nabinchha Feb 28, 2026
ec5ed9b
Address PR #359 feedback: exception wrapping, shared parsing, test im…
nabinchha Mar 4, 2026
b6b4028
Merge branch 'main' into nm/overhaul-model-facade-guts-pr1
nabinchha Mar 4, 2026
ba22397
Use contextlib to dry out some code
nabinchha Mar 4, 2026
aeac3b9
Address Greptile feedback: HTTP-date retry-after parsing, docstring c…
nabinchha Mar 4, 2026
55f3c96
Address Greptile feedback: FastAPI detail parsing, comment fixes
nabinchha Mar 4, 2026
c390912
Merge branch 'main' into nm/overhaul-model-facade-guts-pr1
nabinchha Mar 4, 2026
828cc49
add PR-2 architecture notes for model facade overhaul
nabinchha Mar 4, 2026
89a6d4e
save progress on pr2
nabinchha Mar 5, 2026
e527503
Merge branch 'main' into nm/overhaul-model-facade-guts-pr1
nabinchha Mar 5, 2026
f6fa447
Merge branch 'nm/overhaul-model-facade-guts-pr1' into nm/overhaul-mod…
nabinchha Mar 5, 2026
b8579c2
small refactor
nabinchha Mar 5, 2026
61024c0
address feedback
nabinchha Mar 5, 2026
d47d508
Merge branch 'nm/overhaul-model-facade-guts-pr1' into nm/overhaul-mod…
nabinchha Mar 5, 2026
49a45ba
Address greptile comment in pr1
nabinchha Mar 5, 2026
e8445cc
refactor ProviderError from dataclass to regular Exception
nabinchha Mar 5, 2026
8a385ff
Merge branch 'main' into nm/overhaul-model-facade-guts-pr2
nabinchha Mar 6, 2026
521c1e4
Address greptile feedback
nabinchha Mar 6, 2026
a831d24
Merge branch 'main' into nm/overhaul-model-facade-guts-pr2
nabinchha Mar 6, 2026
4836e03
PR feedback
nabinchha Mar 6, 2026
ae1bf98
track usage tracking in finally block for images
nabinchha Mar 6, 2026
18b9966
pr feedback
nabinchha Mar 6, 2026
25650b0
Merge branch 'main' into nm/overhaul-model-facade-guts-pr2
nabinchha Mar 6, 2026
651813b
Merge branch 'main' into nm/overhaul-model-facade-guts-pr2
nabinchha Mar 6, 2026
bfed5af
Merge branch 'main' into nm/overhaul-model-facade-guts-pr2
nabinchha Mar 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,6 @@ packages/data-designer/README.md
.cursor/rules/cerebro.mdc
.cursor/mcp.json
.claude/rules/cerebro.md

# Claude worktrees
.claude/worktrees/
285 changes: 56 additions & 229 deletions packages/data-designer-engine/src/data_designer/engine/mcp/facade.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
map_http_error_to_provider_error,
map_http_status_to_provider_error_kind,
)
from data_designer.engine.models.clients.factory import create_model_client
from data_designer.engine.models.clients.types import (
AssistantMessage,
ChatCompletionRequest,
Expand All @@ -25,12 +26,12 @@
)

__all__ = [
"HttpResponse",
"AssistantMessage",
"ChatCompletionRequest",
"ChatCompletionResponse",
"EmbeddingRequest",
"EmbeddingResponse",
"HttpResponse",
"ImageGenerationRequest",
"ImageGenerationResponse",
"ImagePayload",
Expand All @@ -39,6 +40,7 @@
"ProviderErrorKind",
"ToolCall",
"Usage",
"create_model_client",
"map_http_error_to_provider_error",
"map_http_status_to_provider_error_kind",
]
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
from data_designer.engine.models.clients.errors import (
ProviderError,
ProviderErrorKind,
extract_message_from_exception_string,
map_http_status_to_provider_error_kind,
)
from data_designer.engine.models.clients.parsing import (
aextract_images_from_chat_response,
aextract_images_from_image_response,
aparse_chat_completion_response,
collect_non_none_optional_fields,
extract_embedding_vector,
extract_images_from_chat_response,
extract_images_from_image_response,
Expand All @@ -32,6 +32,7 @@
EmbeddingResponse,
ImageGenerationRequest,
ImageGenerationResponse,
TransportKwargs,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -75,57 +76,67 @@ def supports_image_generation(self) -> bool:
return True

def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
transport = TransportKwargs.from_request(request)
with _handle_non_provider_errors(self.provider_name):
response = self._router.completion(
model=request.model,
messages=request.messages,
**collect_non_none_optional_fields(request),
extra_headers=transport.headers or None,
**transport.body,
)
return parse_chat_completion_response(response)

async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
transport = TransportKwargs.from_request(request)
with _handle_non_provider_errors(self.provider_name):
response = await self._router.acompletion(
model=request.model,
messages=request.messages,
**collect_non_none_optional_fields(request),
extra_headers=transport.headers or None,
**transport.body,
)
return await aparse_chat_completion_response(response)

def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
transport = TransportKwargs.from_request(request)
with _handle_non_provider_errors(self.provider_name):
response = self._router.embedding(
model=request.model,
input=request.inputs,
**collect_non_none_optional_fields(request),
extra_headers=transport.headers or None,
**transport.body,
)
vectors = [extract_embedding_vector(item) for item in getattr(response, "data", [])]
return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response)

async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
transport = TransportKwargs.from_request(request)
with _handle_non_provider_errors(self.provider_name):
response = await self._router.aembedding(
model=request.model,
input=request.inputs,
**collect_non_none_optional_fields(request),
extra_headers=transport.headers or None,
**transport.body,
)
vectors = [extract_embedding_vector(item) for item in getattr(response, "data", [])]
return EmbeddingResponse(vectors=vectors, usage=extract_usage(getattr(response, "usage", None)), raw=response)

def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE)
transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE)
with _handle_non_provider_errors(self.provider_name):
if request.messages is not None:
response = self._router.completion(
model=request.model,
messages=request.messages,
**image_kwargs,
extra_headers=transport.headers or None,
**transport.body,
)
else:
response = self._router.image_generation(
prompt=request.prompt,
model=request.model,
**image_kwargs,
extra_headers=transport.headers or None,
**transport.body,
)

if request.messages is not None:
Expand All @@ -137,19 +148,21 @@ def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResp
return ImageGenerationResponse(images=images, usage=usage, raw=response)

async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse:
image_kwargs = collect_non_none_optional_fields(request, exclude=self._IMAGE_EXCLUDE)
transport = TransportKwargs.from_request(request, exclude=self._IMAGE_EXCLUDE)
with _handle_non_provider_errors(self.provider_name):
if request.messages is not None:
response = await self._router.acompletion(
model=request.model,
messages=request.messages,
**image_kwargs,
extra_headers=transport.headers or None,
**transport.body,
)
else:
response = await self._router.aimage_generation(
prompt=request.prompt,
model=request.model,
**image_kwargs,
extra_headers=transport.headers or None,
**transport.body,
)

if request.messages is not None:
Expand Down Expand Up @@ -183,7 +196,7 @@ def _handle_non_provider_errors(provider_name: str) -> Iterator[None]:

raise ProviderError(
kind=kind,
message=str(exc),
message=extract_message_from_exception_string(str(exc)),
status_code=status_code if isinstance(status_code, int) else None,
provider_name=provider_name,
cause=exc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import calendar
import email.utils
import json
import time
from dataclasses import dataclass
from enum import Enum

from data_designer.engine.models.clients.types import HttpResponse
Expand All @@ -28,20 +28,26 @@ class ProviderErrorKind(str, Enum):
UNSUPPORTED_CAPABILITY = "unsupported_capability"


@dataclass
class ProviderError(Exception):
kind: ProviderErrorKind
message: str
status_code: int | None = None
provider_name: str | None = None
model_name: str | None = None
retry_after: float | None = None
cause: Exception | None = None

def __post_init__(self) -> None:
Exception.__init__(self, self.message)
if self.cause is not None:
self.__cause__ = self.cause
def __init__(
self,
kind: ProviderErrorKind,
message: str,
status_code: int | None = None,
provider_name: str | None = None,
model_name: str | None = None,
retry_after: float | None = None,
cause: Exception | None = None,
) -> None:
super().__init__(message)
self.kind = kind
self.message = message
self.status_code = status_code
self.provider_name = provider_name
self.model_name = model_name
self.retry_after = retry_after
if cause is not None:
self.__cause__ = cause

def __str__(self) -> str:
return self.message
Expand Down Expand Up @@ -118,6 +124,31 @@ def map_http_error_to_provider_error(
)


def extract_message_from_exception_string(raw: str) -> str:
    """Extract a human-readable message from a stringified LiteLLM exception.

    LiteLLM often formats errors as ``"Error code: 400 - {json}"``. This
    mirrors the structured-key lookup in ``_extract_structured_message`` but
    operates on a raw string instead of an ``HttpResponse``.

    The payload is assumed to start at the first ``{`` in *raw*. It is decoded
    as JSON first, ignoring any text that trails the object. If that fails it
    is parsed as a Python literal, which covers bodies that were interpolated
    as a ``dict`` repr (single-quoted keys, ``None``/``True``/``False``).

    Args:
        raw: The stringified exception.

    Returns:
        The stripped value of the first usable ``message``, ``error`` or
        ``detail`` entry (either a non-blank string, or a nested dict with a
        non-blank string ``message``). *raw* is returned unchanged when no
        payload can be parsed or no such entry exists.
    """
    import ast  # local import: only needed for the dict-repr fallback below

    json_start = raw.find("{")
    if json_start == -1:
        return raw

    try:
        # raw_decode stops at the end of the first JSON value, so text after
        # the payload (e.g. a trailing hint) does not break parsing.
        payload, _ = json.JSONDecoder().raw_decode(raw, json_start)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        try:
            # literal_eval only evaluates literals, so it is safe on
            # untrusted text; it never executes code.
            payload = ast.literal_eval(raw[json_start:].strip())
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return raw

    if not isinstance(payload, dict):
        return raw

    for key in ("message", "error", "detail"):
        value = payload.get(key)
        if isinstance(value, str) and value.strip():
            return value.strip()
        if isinstance(value, dict):
            nested = value.get("message")
            if isinstance(nested, str) and nested.strip():
                return nested.strip()
    return raw


def _extract_response_text(response: HttpResponse) -> str:
# Try structured JSON extraction first — most providers return structured error
# bodies and we want the human-readable message, not raw JSON.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import data_designer.lazy_heavy_imports as lazy
from data_designer.config.models import ModelConfig
from data_designer.engine.model_provider import ModelProviderRegistry
from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient
from data_designer.engine.models.clients.base import ModelClient
from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
from data_designer.engine.secret_resolver import SecretResolver


def create_model_client(
    model_config: ModelConfig,
    secret_resolver: SecretResolver,
    model_provider_registry: ModelProviderRegistry,
) -> ModelClient:
    """Build a ModelClient for a single model configuration.

    Looks up the provider named by the model config, resolves its API key
    through the secret resolver when one is configured, and wraps a
    single-deployment LiteLLM router in a LiteLLMBridgeClient.

    Args:
        model_config: The model configuration to create a client for.
        secret_resolver: Resolver for secrets referenced in provider configs.
        model_provider_registry: Registry of model provider configurations.

    Returns:
        A LiteLLMBridgeClient bound to the provider's name and the new router.
    """
    provider = model_provider_registry.get_provider(model_config.provider)

    # Only consult the resolver when the provider actually declares a key.
    resolved_key = secret_resolver.resolve(provider.api_key) if provider.api_key else None
    if not resolved_key:
        # Placeholder used when no key is configured or it resolves to a falsy value.
        resolved_key = "not-used-but-required"

    params = lazy.litellm.LiteLLM_Params(
        model=f"{provider.provider_type}/{model_config.model}",
        api_base=provider.endpoint,
        api_key=resolved_key,
        max_parallel_requests=model_config.inference_parameters.max_parallel_requests,
    )
    router = CustomRouter(
        [{"model_name": model_config.model, "litellm_params": params.model_dump()}],
        **LiteLLMRouterDefaultKwargs().model_dump(),
    )
    return LiteLLMBridgeClient(provider_name=provider.name, router=router)
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from __future__ import annotations

import dataclasses
import json
import logging
import uuid
from typing import Any

from data_designer.config.utils.image_helpers import (
Expand Down Expand Up @@ -206,7 +206,7 @@ def extract_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:

normalized_tool_calls: list[ToolCall] = []
for raw_tool_call in raw_tool_calls:
tool_call_id = get_value_from(raw_tool_call, "id") or ""
tool_call_id = get_value_from(raw_tool_call, "id") or uuid.uuid4().hex
function = get_value_from(raw_tool_call, "function")
name = get_value_from(function, "name") or ""
arguments_value = get_value_from(function, "arguments")
Expand Down Expand Up @@ -333,17 +333,3 @@ def get_first_value_or_none(values: Any) -> Any | None:
if isinstance(values, list) and values:
return values[0]
return None


def collect_non_none_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]:
    """Gather the optional fields of a request dataclass that the caller actually set.

    A field is forwarded only when all of the following hold: its name is not
    in *exclude*, its declared default is ``None`` (i.e. it is a truly optional
    kwarg), and its current value on *request* is not ``None``. Fields whose
    default is anything other than ``None`` are never forwarded.

    Args:
        request: A dataclass instance to read fields from.
        exclude: Field names to skip unconditionally.

    Returns:
        A mapping of field name to current value, in field declaration order.
    """
    collected: dict[str, Any] = {}
    for spec in dataclasses.fields(request):
        if spec.name in exclude or spec.default is not None:
            continue
        current = getattr(request, spec.name)
        if current is not None:
            collected[spec.name] = current
    return collected
Loading