diff --git a/.gitignore b/.gitignore
index 9f3b78f66..520c5aff9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -107,3 +107,6 @@ packages/data-designer/README.md
 .cursor/rules/cerebro.mdc
 .cursor/mcp.json
 .claude/rules/cerebro.md
+
+# Claude worktrees
+.claude/worktrees/
diff --git a/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py b/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py
index f49069e18..10eecffc4 100644
--- a/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py
+++ b/packages/data-designer-engine/src/data_designer/engine/mcp/facade.py
@@ -4,7 +4,6 @@
 from __future__ import annotations
 
 import json
-import uuid
 from typing import Any
 
 from data_designer.config.mcp import MCPProviderT, ToolConfig
@@ -12,6 +11,7 @@ from data_designer.engine.mcp.errors import DuplicateToolNameError, MCPConfigurationError, MCPToolError
 from data_designer.engine.mcp.registry import MCPToolDefinition
 from data_designer.engine.model_provider import MCPProviderRegistry
+from data_designer.engine.models.clients.types import ChatCompletionResponse, ToolCall
 from data_designer.engine.models.utils import ChatMessage
 from data_designer.engine.secret_resolver import SecretResolver
 
@@ -38,13 +38,6 @@ def __init__(
         self,
         tool_config: ToolConfig,
         secret_resolver: SecretResolver,
         mcp_provider_registry: MCPProviderRegistry,
     ) -> None:
-        """Initialize the MCPFacade.
-
-        Args:
-            tool_config: The tool configuration this facade is scoped to.
-            secret_resolver: Resolver for secrets referenced in provider configs.
-            mcp_provider_registry: Registry of MCP provider configurations.
-        """
         self._tool_config = tool_config
         self._secret_resolver = secret_resolver
         self._mcp_provider_registry = mcp_provider_registry
@@ -79,44 +72,14 @@ def timeout_sec(self) -> float | None:
         return self._tool_config.timeout_sec
 
     @staticmethod
-    def tool_call_count(completion_response: Any) -> int:
-        """Count the number of tool calls in a completion response.
-
-        Args:
-            completion_response: The LLM completion response (litellm.ModelResponse).
-
-        Returns:
-            Number of tool calls in the response (0 if none).
-        """
-        message = completion_response.choices[0].message
-        tool_calls = getattr(message, "tool_calls", None)
-        if tool_calls is None:
-            return 0
-        return len(tool_calls)
+    def get_tool_call_count(completion_response: ChatCompletionResponse) -> int:
+        """Count the number of tool calls in a completion response."""
+        return len(completion_response.message.tool_calls)
 
     @staticmethod
-    def has_tool_calls(completion_response: Any) -> bool:
+    def has_tool_calls(completion_response: ChatCompletionResponse) -> bool:
         """Returns True if tool calls are present in the completion response."""
-        return MCPFacade.tool_call_count(completion_response) > 0
-
-    def _resolve_provider(self, provider: MCPProviderT) -> MCPProviderT:
-        """Resolve secret references in an MCP provider's api_key.
-
-        Creates a copy of the provider with the api_key resolved from any secret
-        reference (e.g., "env:API_KEY") to its actual value.
-
-        Args:
-            provider: The MCP provider config.
-
-        Returns:
-            A copy of the provider with resolved api_key, or the original provider
-            if no api_key is configured.
-        """
-        api_key_ref = getattr(provider, "api_key", None)
-        if not api_key_ref:
-            return provider
-        resolved_key = self._secret_resolver.resolve(api_key_ref)
-        return provider.model_copy(update={"api_key": resolved_key})
+        return len(completion_response.message.tool_calls) > 0
 
     def get_tool_schemas(self) -> list[dict[str, Any]]:
         """Get OpenAI-compatible tool schemas for this configuration.
@@ -168,7 +131,7 @@ def get_tool_schemas(self) -> list[dict[str, Any]]:
 
     def process_completion_response(
         self,
-        completion_response: Any,
+        completion_response: ChatCompletionResponse,
     ) -> list[ChatMessage]:
         """Process an LLM completion response and execute any tool calls.
 
@@ -178,10 +141,7 @@ def process_completion_response(
         tool calls), and returns the messages for continuing the conversation.
 
         Args:
-            completion_response: The completion response object from the LLM,
-                typically from `router.completion()`. Expected to have a
-                `choices[0].message` structure with optional `content`,
-                `reasoning_content`, and `tool_calls` attributes.
+            completion_response: The canonical ChatCompletionResponse from the model client.
 
         Returns:
             A list of ChatMessages to append to the conversation history:
@@ -189,29 +149,23 @@ def process_completion_response(
             - If no tool calls: [assistant_message]
 
         Raises:
-            MCPToolError: If a tool call is missing a name.
-            MCPToolError: If tool call arguments cannot be parsed as JSON.
-            MCPToolError: If tool call arguments are an unsupported type.
             MCPToolError: If a requested tool is not in the allowed tools list.
             MCPToolError: If tool execution fails or times out.
             MCPConfigurationError: If a requested tool is not found on any configured provider.
         """
-        message = completion_response.choices[0].message
+        message = completion_response.message
 
-        # Extract response content and reasoning content
         response_content = message.content or ""
-        reasoning_content = getattr(message, "reasoning_content", None)
+        reasoning_content = message.reasoning_content
 
         # Strip whitespace if reasoning is present (models often add extra newlines)
         if reasoning_content:
             response_content = response_content.strip()
             reasoning_content = reasoning_content.strip()
 
-        # Extract and normalize tool calls
-        tool_calls = self._extract_tool_calls(message)
+        tool_calls = message.tool_calls
         if not tool_calls:
-            # No tool calls - just return the assistant message
             return [
                 ChatMessage.as_assistant(
                     content=response_content,
@@ -220,49 +174,43 @@ def process_completion_response(
             ]
 
         # Has tool calls - execute and return all messages
-        assistant_message = self._build_assistant_tool_message(response_content, tool_calls, reasoning_content)
-        tool_messages = self._execute_tool_calls_internal(tool_calls)
+        tool_call_dicts = _convert_canonical_tool_calls_to_dicts(tool_calls)
+        assistant_message = self._build_assistant_tool_message(response_content, tool_call_dicts, reasoning_content)
+        tool_messages = self._execute_tool_calls_from_canonical(tool_calls)
         return [assistant_message, *tool_messages]
 
     def refuse_completion_response(
         self,
-        completion_response: Any,
+        completion_response: ChatCompletionResponse,
         refusal_message: str | None = None,
     ) -> list[ChatMessage]:
         """Refuse tool calls without executing them.
 
         Used when the tool call turn budget is exhausted. Returns messages
         that include the assistant's tool call request but with refusal
-        responses instead of actual tool results. This allows the model
-        to gracefully degrade and provide a final response without tools.
+        responses instead of actual tool results.
 
         Args:
-            completion_response: The LLM completion response containing tool calls.
-            refusal_message: Optional custom refusal message. Defaults to a
-                standard message about tool budget exhaustion.
+            completion_response: The canonical ChatCompletionResponse containing tool calls.
+            refusal_message: Optional custom refusal message.
 
         Returns:
-            A list of ChatMessages to append to the conversation history:
-            - If tool calls were present: [assistant_message_with_tool_calls, *refusal_messages]
-            - If no tool calls: [assistant_message]
+            A list of ChatMessages to append to the conversation history.
         """
-        message = completion_response.choices[0].message
+        message = completion_response.message
 
-        # Extract response content and reasoning content
         response_content = message.content or ""
-        reasoning_content = getattr(message, "reasoning_content", None)
+        reasoning_content = message.reasoning_content
 
         # Strip whitespace if reasoning is present (models often add extra newlines)
         if reasoning_content:
             response_content = response_content.strip()
             reasoning_content = reasoning_content.strip()
 
-        # Extract and normalize tool calls
-        tool_calls = self._extract_tool_calls(message)
+        tool_calls = message.tool_calls
         if not tool_calls:
-            # No tool calls to refuse - just return assistant message
             return [
                 ChatMessage.as_assistant(
                     content=response_content,
@@ -271,115 +219,22 @@ def refuse_completion_response(
             ]
 
         # Build assistant message with tool calls (same as normal)
-        assistant_message = self._build_assistant_tool_message(response_content, tool_calls, reasoning_content)
+        tool_call_dicts = _convert_canonical_tool_calls_to_dicts(tool_calls)
+        assistant_message = self._build_assistant_tool_message(response_content, tool_call_dicts, reasoning_content)
 
         # Build refusal messages instead of executing tools
         refusal = refusal_message or DEFAULT_TOOL_REFUSAL_MESSAGE
-        tool_messages = [ChatMessage.as_tool(content=refusal, tool_call_id=tc["id"]) for tc in tool_calls]
+        tool_messages = [ChatMessage.as_tool(content=refusal, tool_call_id=tc.id) for tc in tool_calls]
 
         return [assistant_message, *tool_messages]
 
-    def _extract_tool_calls(self, message: Any) -> list[dict[str, Any]]:
-        """Extract and normalize tool calls from an LLM response message.
-
-        Handles various LLM response formats (dict or object with attributes)
-        and normalizes them into a consistent dictionary format. Supports
-        parallel tool calling where the model returns multiple tool calls
-        in a single response.
-
-        Args:
-            message: The LLM response message, either as a dictionary or an object
-                with a 'tool_calls' attribute.
-
-        Returns:
-            A list of normalized tool call dictionaries. Each dictionary contains:
-            - 'id': Unique identifier for the tool call (generated if not provided)
-            - 'name': The name of the tool to call
-            - 'arguments': Parsed arguments as a dictionary
-            - 'arguments_json': Arguments serialized as a JSON string
-
-            Returns an empty list if no tool calls are present in the message.
-
-        Raises:
-            MCPToolError: If a tool call is missing a name.
-            MCPToolError: If tool call arguments cannot be parsed as JSON.
-            MCPToolError: If tool call arguments are an unsupported type.
-        """
-        raw_tool_calls = getattr(message, "tool_calls", None)
-        if raw_tool_calls is None and isinstance(message, dict):
-            raw_tool_calls = message.get("tool_calls")
-        if not raw_tool_calls:
-            return []
-
-        tool_calls: list[dict[str, Any]] = []
-        for raw_tool_call in raw_tool_calls:
-            tool_calls.append(self._normalize_tool_call(raw_tool_call))
-        return tool_calls
-
-    def _normalize_tool_call(self, raw_tool_call: Any) -> dict[str, Any]:
-        """Normalize a tool call from various LLM response formats.
-
-        Handles both dictionary and object representations of tool calls,
-        supporting the OpenAI format (with nested 'function' key) and
-        flattened formats.
-
-        Args:
-            raw_tool_call: A tool call in any supported format.
-
-        Returns:
-            A normalized tool call dictionary with keys:
-            - 'id': Tool call identifier (UUID generated if not provided)
-            - 'name': The tool name
-            - 'arguments': Parsed arguments dictionary
-            - 'arguments_json': JSON string of arguments
-
-        Raises:
-            MCPToolError: If the tool call is missing a name or has invalid
-                arguments that cannot be parsed as JSON.
-        """
-        if isinstance(raw_tool_call, dict):
-            tool_call_id = raw_tool_call.get("id")
-            function = raw_tool_call.get("function") or {}
-            name = function.get("name") or raw_tool_call.get("name")
-            arguments = function.get("arguments") or raw_tool_call.get("arguments")
-        else:
-            tool_call_id = getattr(raw_tool_call, "id", None)
-            function = getattr(raw_tool_call, "function", None)
-            name = getattr(function, "name", None) if function is not None else getattr(raw_tool_call, "name", None)
-            arguments = (
-                getattr(function, "arguments", None)
-                if function is not None
-                else getattr(raw_tool_call, "arguments", None)
-            )
-
-        if not name:
-            raise MCPToolError("MCP tool call is missing a tool name.")
-
-        arguments_payload: dict[str, Any]
-        if arguments is None or arguments == "":
-            arguments_payload = {}
-        elif isinstance(arguments, str):
-            try:
-                arguments_payload = json.loads(arguments)
-            except json.JSONDecodeError as exc:
-                raise MCPToolError(f"Invalid tool arguments for '{name}': {arguments}") from exc
-        elif isinstance(arguments, dict):
-            arguments_payload = arguments
-        else:
-            raise MCPToolError(f"Unsupported tool arguments type for '{name}': {type(arguments)!r}")
-
-        # Normalize arguments_json to ensure valid, canonical JSON
-        try:
-            arguments_json = json.dumps(arguments_payload)
-        except TypeError as exc:
-            raise MCPToolError(f"Non-serializable tool arguments for '{name}': {exc}") from exc
-
-        return {
-            "id": tool_call_id or uuid.uuid4().hex,
-            "name": name,
-            "arguments": arguments_payload,
-            "arguments_json": arguments_json,
-        }
+    def _resolve_provider(self, provider: MCPProviderT) -> MCPProviderT:
+        """Resolve secret references in an MCP provider's api_key."""
+        api_key_ref = getattr(provider, "api_key", None)
+        if not api_key_ref:
+            return provider
+        resolved_key = self._secret_resolver.resolve(api_key_ref)
+        return provider.model_copy(update={"api_key": resolved_key})
 
     def _build_assistant_tool_message(
         self,
@@ -387,21 +242,7 @@ def _build_assistant_tool_message(
         tool_calls: list[dict[str, Any]],
         reasoning_content: str | None = None,
     ) -> ChatMessage:
-        """Build the assistant message containing tool call requests.
-
-        Constructs a message in the format expected by the LLM conversation
-        history, representing the assistant's request to call tools.
-
-        Args:
-            response: The assistant's text response content. May be empty if
-                the assistant only requested tool calls without additional text.
-            tool_calls: List of normalized tool call dictionaries.
-            reasoning_content: Optional reasoning content from the assistant's
-                response. If provided, will be included under the 'reasoning_content' key.
-
-        Returns:
-            A ChatMessage representing the assistant message with tool call requests.
-        """
+        """Build the assistant message containing tool call requests."""
         tool_calls_payload = [
             {
                 "id": tool_call["id"],
@@ -416,38 +257,23 @@ def _build_assistant_tool_message(
             tool_calls=tool_calls_payload,
         )
 
-    def _execute_tool_calls_internal(
+    def _execute_tool_calls_from_canonical(
         self,
-        tool_calls: list[dict[str, Any]],
+        tool_calls: list[ToolCall],
     ) -> list[ChatMessage]:
-        """Execute tool calls in parallel and return tool response messages.
-
-        Validates all tool calls, then executes them concurrently via the io module
-        using call_tools_parallel. This leverages parallel tool calling when the
-        model returns multiple tool calls in a single response.
-
-        Args:
-            tool_calls: List of normalized tool call dictionaries to execute.
-
-        Returns:
-            A list of tool response messages, one per tool call.
-
-        Raises:
-            MCPToolError: If a tool is not in the allowed tools list or if
-                the MCP provider returns an error.
-        """
+        """Execute canonical ToolCall objects and return tool response messages."""
         allowed_tools = set(self._tool_config.allow_tools) if self._tool_config.allow_tools else None
 
-        # Validate all tool calls and collect provider + args
         calls_to_execute: list[tuple[MCPProviderT, str, dict[str, Any], str]] = []
-        for tool_call in tool_calls:
-            tool_name = tool_call["name"]
-            if allowed_tools is not None and tool_name not in allowed_tools:
+        for tc in tool_calls:
+            if allowed_tools is not None and tc.name not in allowed_tools:
                 providers_str = ", ".join(repr(p) for p in self._tool_config.providers)
-                raise MCPToolError(f"Tool {tool_name!r} is not permitted for providers: {providers_str}.")
+                raise MCPToolError(f"Tool {tc.name!r} is not permitted for providers: {providers_str}.")
 
-            resolved_provider = self._find_resolved_provider_for_tool(tool_name)
-            calls_to_execute.append((resolved_provider, tool_name, tool_call["arguments"], tool_call["id"]))
+            arguments_raw = json.loads(tc.arguments_json) if tc.arguments_json else {}
+            arguments = arguments_raw if isinstance(arguments_raw, dict) else {}
+            resolved_provider = self._find_resolved_provider_for_tool(tc.name)
+            calls_to_execute.append((resolved_provider, tc.name, arguments, tc.id))
 
         # Execute all calls in parallel
         results = mcp_io.call_tools(
@@ -455,24 +281,13 @@ def _execute_tool_calls_internal(
             timeout_sec=self._tool_config.timeout_sec,
         )
 
-        # Build response messages
         return [
             ChatMessage.as_tool(content=result.content, tool_call_id=call[3])
             for result, call in zip(results, calls_to_execute)
         ]
 
     def _find_resolved_provider_for_tool(self, tool_name: str) -> MCPProviderT:
-        """Find the provider that has the given tool and return it with resolved api_key.
-
-        Args:
-            tool_name: The name of the tool to find.
-
-        Returns:
-            The provider object (with resolved api_key) that has the tool.
-
-        Raises:
-            MCPConfigurationError: If no provider has the tool.
-        """
+        """Find the provider that has the given tool and return it with resolved api_key."""
         for provider_name in self._tool_config.providers:
             provider = self._mcp_provider_registry.get_provider(provider_name)
             resolved_provider = self._resolve_provider(provider)
@@ -483,3 +298,15 @@ def _find_resolved_provider_for_tool(self, tool_name: str) -> MCPProviderT:
                 return resolved_provider
 
         raise MCPConfigurationError(f"Tool {tool_name!r} not found on any configured provider.")
+
+
+def _convert_canonical_tool_calls_to_dicts(tool_calls: list[ToolCall]) -> list[dict[str, Any]]:
+    """Convert canonical ToolCall objects to the internal dict format for ChatMessage."""
+    return [
+        {
+            "id": tc.id,
+            "name": tc.name,
+            "arguments_json": tc.arguments_json,
+        }
+        for tc in tool_calls
+    ]
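Reviewer note, not part of the patch: a minimal sketch of the canonical tool-call surface this file now consumes. It assumes the dataclasses from `clients/types.py` are directly constructible with the fields used in this diff (`ChatCompletionResponse.message`, `AssistantMessage.content/reasoning_content/tool_calls`, `ToolCall.id/name/arguments_json`); constructor defaults are an assumption.

```python
# Hypothetical usage sketch (field names taken from this diff; defaults assumed).
from data_designer.engine.mcp.facade import MCPFacade
from data_designer.engine.models.clients.types import (
    AssistantMessage,
    ChatCompletionResponse,
    ToolCall,
)

response = ChatCompletionResponse(
    message=AssistantMessage(
        content="",
        reasoning_content=None,
        tool_calls=[
            ToolCall(id="call_1", name="search_docs", arguments_json='{"query": "mcp"}'),
        ],
    ),
)

# The facade now counts tool calls on the canonical message instead of
# reaching into litellm's response.choices[0].message.
assert MCPFacade.has_tool_calls(response)
assert MCPFacade.get_tool_call_count(response) == 1
```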
- """ + """Find the provider that has the given tool and return it with resolved api_key.""" for provider_name in self._tool_config.providers: provider = self._mcp_provider_registry.get_provider(provider_name) resolved_provider = self._resolve_provider(provider) @@ -483,3 +298,15 @@ def _find_resolved_provider_for_tool(self, tool_name: str) -> MCPProviderT: return resolved_provider raise MCPConfigurationError(f"Tool {tool_name!r} not found on any configured provider.") + + +def _convert_canonical_tool_calls_to_dicts(tool_calls: list[ToolCall]) -> list[dict[str, Any]]: + """Convert canonical ToolCall objects to the internal dict format for ChatMessage.""" + return [ + { + "id": tc.id, + "name": tc.name, + "arguments_json": tc.arguments_json, + } + for tc in tool_calls + ] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index 9cd9bcc62..fc72e1e4b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -10,6 +10,7 @@ map_http_error_to_provider_error, map_http_status_to_provider_error_kind, ) +from data_designer.engine.models.clients.factory import create_model_client from data_designer.engine.models.clients.types import ( AssistantMessage, ChatCompletionRequest, @@ -25,12 +26,12 @@ ) __all__ = [ - "HttpResponse", "AssistantMessage", "ChatCompletionRequest", "ChatCompletionResponse", "EmbeddingRequest", "EmbeddingResponse", + "HttpResponse", "ImageGenerationRequest", "ImageGenerationResponse", "ImagePayload", @@ -39,6 +40,7 @@ "ProviderErrorKind", "ToolCall", "Usage", + "create_model_client", "map_http_error_to_provider_error", "map_http_status_to_provider_error_kind", ] diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py index a7bc73b63..f5b861b4e 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/litellm_bridge.py @@ -12,13 +12,13 @@ from data_designer.engine.models.clients.errors import ( ProviderError, ProviderErrorKind, + extract_message_from_exception_string, map_http_status_to_provider_error_kind, ) from data_designer.engine.models.clients.parsing import ( aextract_images_from_chat_response, aextract_images_from_image_response, aparse_chat_completion_response, - collect_non_none_optional_fields, extract_embedding_vector, extract_images_from_chat_response, extract_images_from_image_response, @@ -32,6 +32,7 @@ EmbeddingResponse, ImageGenerationRequest, ImageGenerationResponse, + TransportKwargs, ) logger = logging.getLogger(__name__) @@ -75,57 +76,67 @@ def supports_image_generation(self) -> bool: return True def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + transport = TransportKwargs.from_request(request) with _handle_non_provider_errors(self.provider_name): response = self._router.completion( model=request.model, messages=request.messages, - **collect_non_none_optional_fields(request), + extra_headers=transport.headers or None, + **transport.body, ) return parse_chat_completion_response(response) async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + transport = 
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py
index 8e8a3b0ac..da3c19383 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/errors.py
@@ -5,8 +5,8 @@
 
 import calendar
 import email.utils
+import json
 import time
-from dataclasses import dataclass
 from enum import Enum
 
 from data_designer.engine.models.clients.types import HttpResponse
@@ -28,20 +28,26 @@ class ProviderErrorKind(str, Enum):
     UNSUPPORTED_CAPABILITY = "unsupported_capability"
 
 
-@dataclass
 class ProviderError(Exception):
-    kind: ProviderErrorKind
-    message: str
-    status_code: int | None = None
-    provider_name: str | None = None
-    model_name: str | None = None
-    retry_after: float | None = None
-    cause: Exception | None = None
-
-    def __post_init__(self) -> None:
-        Exception.__init__(self, self.message)
-        if self.cause is not None:
-            self.__cause__ = self.cause
+    def __init__(
+        self,
+        kind: ProviderErrorKind,
+        message: str,
+        status_code: int | None = None,
+        provider_name: str | None = None,
+        model_name: str | None = None,
+        retry_after: float | None = None,
+        cause: Exception | None = None,
+    ) -> None:
+        super().__init__(message)
+        self.kind = kind
+        self.message = message
+        self.status_code = status_code
+        self.provider_name = provider_name
+        self.model_name = model_name
+        self.retry_after = retry_after
+        if cause is not None:
+            self.__cause__ = cause
 
     def __str__(self) -> str:
         return self.message
@@ -118,6 +124,31 @@ def map_http_error_to_provider_error(
     )
 
 
+def extract_message_from_exception_string(raw: str) -> str:
+    """Extract a human-readable message from a stringified LiteLLM exception.
+
+    LiteLLM often formats errors as ``"Error code: 400 - {json}"``. This
+    mirrors the structured-key lookup in ``_extract_structured_message`` but
+    operates on a raw string instead of an ``HttpResponse``.
+    """
+    json_start = raw.find("{")
+    if json_start != -1:
+        try:
+            payload = json.loads(raw[json_start:])
+        except (json.JSONDecodeError, ValueError):
+            return raw
+        if isinstance(payload, dict):
+            for key in ("message", "error", "detail"):
+                value = payload.get(key)
+                if isinstance(value, str) and value.strip():
+                    return value.strip()
+                if isinstance(value, dict):
+                    nested = value.get("message")
+                    if isinstance(nested, str) and nested.strip():
+                        return nested.strip()
+    return raw
+
+
 def _extract_response_text(response: HttpResponse) -> str:
     # Try structured JSON extraction first — most providers return structured error
     # bodies and we want the human-readable message, not raw JSON.
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py
new file mode 100644
index 000000000..c7e32ebcd
--- /dev/null
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py
@@ -0,0 +1,50 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import data_designer.lazy_heavy_imports as lazy
+from data_designer.config.models import ModelConfig
+from data_designer.engine.model_provider import ModelProviderRegistry
+from data_designer.engine.models.clients.adapters.litellm_bridge import LiteLLMBridgeClient
+from data_designer.engine.models.clients.base import ModelClient
+from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs
+from data_designer.engine.secret_resolver import SecretResolver
+
+
+def create_model_client(
+    model_config: ModelConfig,
+    secret_resolver: SecretResolver,
+    model_provider_registry: ModelProviderRegistry,
+) -> ModelClient:
+    """Create a ModelClient for the given model configuration.
+
+    Resolves the provider, API key, and constructs a LiteLLM router wrapped in
+    a LiteLLMBridgeClient adapter.
+
+    Args:
+        model_config: The model configuration to create a client for.
+        secret_resolver: Resolver for secrets referenced in provider configs.
+        model_provider_registry: Registry of model provider configurations.
+
+    Returns:
+        A ModelClient instance ready for use.
+    """
+    provider = model_provider_registry.get_provider(model_config.provider)
+    api_key = None
+    if provider.api_key:
+        api_key = secret_resolver.resolve(provider.api_key)
+    api_key = api_key or "not-used-but-required"
+
+    litellm_params = lazy.litellm.LiteLLM_Params(
+        model=f"{provider.provider_type}/{model_config.model}",
+        api_base=provider.endpoint,
+        api_key=api_key,
+        max_parallel_requests=model_config.inference_parameters.max_parallel_requests,
+    )
+    deployment = {
+        "model_name": model_config.model,
+        "litellm_params": litellm_params.model_dump(),
+    }
+    router = CustomRouter([deployment], **LiteLLMRouterDefaultKwargs().model_dump())
+    return LiteLLMBridgeClient(provider_name=provider.name, router=router)
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py
index b565ee87b..e5d74d440 100644
--- a/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py
+++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/parsing.py
@@ -5,9 +5,9 @@
 
 from __future__ import annotations
 
-import dataclasses
 import json
 import logging
+import uuid
 from typing import Any
 
 from data_designer.config.utils.image_helpers import (
@@ -206,7 +206,7 @@ def extract_tool_calls(raw_tool_calls: Any) -> list[ToolCall]:
 
     normalized_tool_calls: list[ToolCall] = []
     for raw_tool_call in raw_tool_calls:
-        tool_call_id = get_value_from(raw_tool_call, "id") or ""
+        tool_call_id = get_value_from(raw_tool_call, "id") or uuid.uuid4().hex
         function = get_value_from(raw_tool_call, "function")
         name = get_value_from(function, "name") or ""
         arguments_value = get_value_from(function, "arguments")
@@ -333,17 +333,3 @@ def get_first_value_or_none(values: Any) -> Any | None:
     if isinstance(values, list) and values:
         return values[0]
     return None
-
-
-def collect_non_none_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]:
-    """Extract non-None optional fields from a request dataclass, skipping *exclude*.
-
-    The ``f.default is None`` check intentionally targets fields whose default is
-    ``None`` — i.e. truly optional kwargs the caller may or may not set. Fields with
-    non-``None`` defaults are not "optional" in this forwarding sense and are excluded.
-    """
-    return {
-        f.name: v
-        for f in dataclasses.fields(request)
-        if f.name not in exclude and f.default is None and (v := getattr(request, f.name)) is not None
-    }
- """ - return { - f.name: v - for f in dataclasses.fields(request) - if f.name not in exclude and f.default is None and (v := getattr(request, f.name)) is not None - } diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py index 3df379910..d83c7a121 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/types.py @@ -3,8 +3,8 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any, Protocol +from dataclasses import dataclass, field, fields +from typing import Any, ClassVar, Protocol class HttpResponse(Protocol): @@ -54,6 +54,12 @@ class ChatCompletionRequest: temperature: float | None = None top_p: float | None = None max_tokens: int | None = None + stop: str | list[str] | None = None + seed: int | None = None + response_format: dict[str, Any] | None = None + frequency_penalty: float | None = None + presence_penalty: float | None = None + n: int | None = None timeout: float | None = None extra_body: dict[str, Any] | None = None extra_headers: dict[str, str] | None = None @@ -101,3 +107,55 @@ class ImageGenerationResponse: images: list[ImagePayload] usage: Usage | None = None raw: Any | None = None + + +# --------------------------------------------------------------------------- +# Transport preparation +# --------------------------------------------------------------------------- + + +@dataclass +class TransportKwargs: + """Pre-processed kwargs ready for an HTTP client call. + + Adapters call ``TransportKwargs.from_request(request)`` instead of + manually handling ``extra_body`` / ``extra_headers`` on every request type. + + - ``body``: API-level keyword arguments with ``extra_body`` keys merged + into the top level (mirroring how LiteLLM flattens them). + - ``headers``: Extra HTTP headers to attach to the outgoing request. + """ + + _META_FIELDS: ClassVar[frozenset[str]] = frozenset({"extra_body", "extra_headers"}) + + body: dict[str, Any] + headers: dict[str, str] + + @classmethod + def from_request(cls, request: Any, *, exclude: frozenset[str] = frozenset()) -> TransportKwargs: + """Build transport-ready kwargs from a canonical request dataclass. + + 1. Collects all non-None optional fields (respecting *exclude*). + 2. Pops ``extra_body`` and merges its keys into the top-level body dict. + 3. Pops ``extra_headers`` into a separate headers dict. + """ + optional_fields = cls._collect_optional_fields(request, exclude=exclude | cls._META_FIELDS) + + extra_body = getattr(request, "extra_body", None) or {} + extra_headers = getattr(request, "extra_headers", None) or {} + + return cls(body={**optional_fields, **extra_body}, headers=dict(extra_headers)) + + @staticmethod + def _collect_optional_fields(request: Any, *, exclude: frozenset[str] = frozenset()) -> dict[str, Any]: + """Extract non-None optional fields from a request dataclass, skipping *exclude*. + + Targets fields whose default is ``None`` — i.e. truly optional kwargs + the caller may or may not set. Fields with non-``None`` defaults are + not "optional" in this forwarding sense and are excluded. 
+ """ + return { + f.name: v + for f in fields(request) + if f.name not in exclude and f.default is None and (v := getattr(request, f.name)) is not None + } diff --git a/packages/data-designer-engine/src/data_designer/engine/models/errors.py b/packages/data-designer-engine/src/data_designer/engine/models/errors.py index e29c81325..6ad084fa7 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/errors.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/errors.py @@ -6,12 +6,13 @@ import logging from collections.abc import Callable from functools import wraps -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NoReturn from pydantic import BaseModel import data_designer.lazy_heavy_imports as lazy from data_designer.engine.errors import DataDesignerError +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind if TYPE_CHECKING: import litellm @@ -34,8 +35,7 @@ def get_exception_primary_cause(exception: BaseException) -> BaseException: """ if exception.__cause__ is None: return exception - else: - return get_exception_primary_cause(exception.__cause__) + return get_exception_primary_cause(exception.__cause__) class GenerationValidationFailureError(Exception): ... @@ -124,7 +124,18 @@ def handle_llm_exceptions( ) err_msg_parser = DownstreamLLMExceptionMessageParser(model_name, model_provider_name, purpose) match exception: - # Common errors that can come from LiteLLM + # Canonical ProviderError from the client adapter layer + case ProviderError(): + _raise_from_provider_error( + exception, + exception.kind, + model_name, + model_provider_name, + purpose, + authentication_error, + ) + + # LiteLLM-specific errors (safety net during bridge period) case lazy.litellm.exceptions.APIError(): raise err_msg_parser.parse_api_error(exception, authentication_error) from None @@ -228,7 +239,7 @@ def catch_llm_exceptions(func: Callable) -> Callable: """ @wraps(func) - def wrapper(model_facade: Any, *args, **kwargs): + def wrapper(model_facade: Any, *args: Any, **kwargs: Any) -> Any: try: return func(model_facade, *args, **kwargs) except Exception as e: @@ -315,7 +326,7 @@ def parse_context_window_exceeded_error( ) def parse_api_error( - self, exception: litellm.exceptions.InternalServerError, auth_error_msg: FormattedLLMErrorMessage + self, exception: litellm.exceptions.APIError, auth_error_msg: FormattedLLMErrorMessage ) -> DataDesignerError: if "Error code: 403" in str(exception): return ModelAuthenticationError(auth_error_msg) @@ -326,3 +337,95 @@ def parse_api_error( solution=f"Try again in a few moments. 
Check with your model provider {self.model_provider_name!r} if the issue persists.", ) ) + + +def _raise_from_provider_error( + exception: ProviderError, + kind: ProviderErrorKind, + model_name: str, + model_provider_name: str, + purpose: str, + authentication_error: FormattedLLMErrorMessage, +) -> NoReturn: + """Map a canonical ProviderError to the appropriate DataDesignerError subclass.""" + _KIND_MAP: dict[ProviderErrorKind, type[DataDesignerError]] = { + ProviderErrorKind.RATE_LIMIT: ModelRateLimitError, + ProviderErrorKind.TIMEOUT: ModelTimeoutError, + ProviderErrorKind.NOT_FOUND: ModelNotFoundError, + ProviderErrorKind.PERMISSION_DENIED: ModelPermissionDeniedError, + ProviderErrorKind.UNSUPPORTED_PARAMS: ModelUnsupportedParamsError, + ProviderErrorKind.INTERNAL_SERVER: ModelInternalServerError, + ProviderErrorKind.UNPROCESSABLE_ENTITY: ModelUnprocessableEntityError, + ProviderErrorKind.API_CONNECTION: ModelAPIConnectionError, + } + + _MESSAGES: dict[ProviderErrorKind, tuple[str, str]] = { + ProviderErrorKind.RATE_LIMIT: ( + f"You have exceeded the rate limit for model {model_name!r} while {purpose}.", + "Wait and try again in a few moments.", + ), + ProviderErrorKind.TIMEOUT: ( + f"The request to model {model_name!r} timed out while {purpose}.", + "Check your connection and try again. You may need to increase the timeout setting for the model.", + ), + ProviderErrorKind.NOT_FOUND: ( + f"The specified model {model_name!r} could not be found while {purpose}.", + f"Check that the model name is correct and supported by your model provider {model_provider_name!r} and try again.", + ), + ProviderErrorKind.PERMISSION_DENIED: ( + f"Your API key was found to lack the necessary permissions to use model {model_name!r} while {purpose}.", + f"Use an API key that has the right permissions for the model or use a model the API key in use has access to in model provider {model_provider_name!r}.", + ), + ProviderErrorKind.UNSUPPORTED_PARAMS: ( + f"One or more of the parameters you provided were found to be unsupported by model {model_name!r} while {purpose}.", + f"Review the documentation for model provider {model_provider_name!r} and adjust your request.", + ), + ProviderErrorKind.INTERNAL_SERVER: ( + f"Model {model_name!r} is currently experiencing internal server issues while {purpose}.", + f"Try again in a few moments. Check with your model provider {model_provider_name!r} if the issue persists.", + ), + ProviderErrorKind.UNPROCESSABLE_ENTITY: ( + f"The request to model {model_name!r} failed despite correct request format while {purpose}.", + "This is most likely temporary. Try again in a few moments.", + ), + ProviderErrorKind.API_CONNECTION: ( + f"Connection to model {model_name!r} hosted on model provider {model_provider_name!r} failed while {purpose}.", + "Check your network/proxy/firewall settings.", + ), + } + + if kind == ProviderErrorKind.AUTHENTICATION: + raise ModelAuthenticationError(authentication_error) from None + + if kind == ProviderErrorKind.CONTEXT_WINDOW_EXCEEDED: + raise ModelContextWindowExceededError( + FormattedLLMErrorMessage( + cause=f"The input data for model '{model_name}' was found to exceed its supported context width while {purpose}.", + solution="Check the model's supported max context width. 
Adjust the length of your input along with completions and try again.", + ) + ) from None + + if kind == ProviderErrorKind.BAD_REQUEST: + err_msg = FormattedLLMErrorMessage( + cause=f"The request for model {model_name!r} was found to be malformed or missing required parameters while {purpose}.", + solution="Check your request parameters and try again.", + ) + if "is not a multimodal model" in str(exception): + err_msg = FormattedLLMErrorMessage( + cause=f"Model {model_name!r} is not a multimodal model, but it looks like you are trying to provide multimodal context while {purpose}.", + solution="Check your request parameters and try again.", + ) + raise ModelBadRequestError(err_msg) from None + + if kind in _KIND_MAP and kind in _MESSAGES: + error_cls = _KIND_MAP[kind] + cause_str, solution_str = _MESSAGES[kind] + raise error_cls(FormattedLLMErrorMessage(cause=cause_str, solution=solution_str)) from None + + # Fallback for API_ERROR and UNSUPPORTED_CAPABILITY + raise ModelAPIError( + FormattedLLMErrorMessage( + cause=f"An unexpected API error occurred with model {model_name!r} while {purpose}.", + solution=f"Try again in a few moments. Check with your model provider {model_provider_name!r} if the issue persists.", + ) + ) from None diff --git a/packages/data-designer-engine/src/data_designer/engine/models/facade.py b/packages/data-designer-engine/src/data_designer/engine/models/facade.py index 13a0c1634..dd9b37e89 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/facade.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/facade.py @@ -9,16 +9,19 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any -import data_designer.lazy_heavy_imports as lazy from data_designer.config.models import GenerationType, ModelConfig, ModelProvider -from data_designer.config.utils.image_helpers import ( - extract_base64_from_data_uri, - is_base64_image, - is_image_diffusion_model, - load_image_url_to_base64, -) +from data_designer.config.utils.image_helpers import is_image_diffusion_model from data_designer.engine.mcp.errors import MCPConfigurationError from data_designer.engine.model_provider import ModelProviderRegistry +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, + Usage, +) from data_designer.engine.models.errors import ( GenerationValidationFailureError, ImageGenerationError, @@ -26,17 +29,14 @@ catch_llm_exceptions, get_exception_primary_cause, ) -from data_designer.engine.models.litellm_overrides import CustomRouter, LiteLLMRouterDefaultKwargs from data_designer.engine.models.parsers.errors import ParserException from data_designer.engine.models.usage import ImageUsageStats, ModelUsageStats, RequestUsageStats, TokenUsageStats from data_designer.engine.models.utils import ChatMessage, prompt_to_messages -from data_designer.engine.secret_resolver import SecretResolver if TYPE_CHECKING: - import litellm - from data_designer.engine.mcp.facade import MCPFacade from data_designer.engine.mcp.registry import MCPRegistry + from data_designer.engine.models.clients.base import ModelClient def _identity(x: Any) -> Any: @@ -44,50 +44,42 @@ def _identity(x: Any) -> Any: return x -def _try_extract_base64(source: str | litellm.types.utils.ImageObject) -> str | None: - """Try to extract base64 image data from a data URI string or image response object. - - Args: - source: Either a data URI string (e.g. 
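Reviewer note, not part of the patch: a sketch of what the adapter layer now raises and what callers can rely on. The constructor keywords come straight from the new explicit `ProviderError.__init__` above; the `handle_llm_exceptions` match arm then re-raises it as the corresponding `DataDesignerError` subclass (here, `ModelRateLimitError`).

```python
# Demonstrates the canonical error object itself; the re-raise path through
# handle_llm_exceptions is described in the diff above and not re-run here.
from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind

try:
    raise ProviderError(
        kind=ProviderErrorKind.RATE_LIMIT,
        message="429 Too Many Requests",
        status_code=429,
        provider_name="nvidia",  # hypothetical provider name
    )
except ProviderError as exc:
    assert exc.kind is ProviderErrorKind.RATE_LIMIT
    assert str(exc) == "429 Too Many Requests"  # __str__ returns the message
```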
"data:image/png;base64,...") - or a litellm ImageObject with b64_json/url attributes. - - Returns: - Base64-encoded image string, or None if extraction fails. - """ - try: - if isinstance(source, str): - return extract_base64_from_data_uri(source) - - if getattr(source, "b64_json", None): - return source.b64_json - - if getattr(source, "url", None): - return load_image_url_to_base64(source.url) - except Exception: - logger.debug(f"Failed to extract base64 from source of type {type(source).__name__}") - return None - - return None +logger = logging.getLogger(__name__) -logger = logging.getLogger(__name__) +# Known keyword arguments extracted into ChatCompletionRequest fields. +_COMPLETION_REQUEST_FIELDS = frozenset( + { + "temperature", + "top_p", + "max_tokens", + "stop", + "seed", + "response_format", + "frequency_penalty", + "presence_penalty", + "n", + "timeout", + "tools", + "extra_body", + "extra_headers", + } +) class ModelFacade: def __init__( self, model_config: ModelConfig, - secret_resolver: SecretResolver, model_provider_registry: ModelProviderRegistry, *, + client: ModelClient, mcp_registry: MCPRegistry | None = None, ) -> None: self._model_config = model_config - self._secret_resolver = secret_resolver self._model_provider_registry = model_provider_registry + self._client = client self._mcp_registry = mcp_registry - self._litellm_deployment = self._get_litellm_deployment(model_config) - self._router = CustomRouter([self._litellm_deployment], **LiteLLMRouterDefaultKwargs().model_dump()) self._usage_stats = ModelUsageStats() @property @@ -118,9 +110,21 @@ def max_parallel_requests(self) -> int: def usage_stats(self) -> ModelUsageStats: return self._usage_stats + def consolidate_kwargs(self, **kwargs: Any) -> dict[str, Any]: + # Remove purpose from kwargs to avoid passing it to the model + kwargs.pop("purpose", None) + kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} + if self.model_provider.extra_body: + kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} + if self.model_provider.extra_headers: + kwargs["extra_headers"] = self.model_provider.extra_headers + return kwargs + + # --- completion / acompletion --- + def completion( - self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs - ) -> litellm.ModelResponse: + self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + ) -> ChatCompletionResponse: message_payloads = [message.to_dict() for message in messages] logger.debug( f"Prompting model {self.model_name!r}...", @@ -129,32 +133,56 @@ def completion( response = None kwargs = self.consolidate_kwargs(**kwargs) try: - response = self._router.completion(model=self.model_name, messages=message_payloads, **kwargs) + request = self._build_chat_completion_request(message_payloads, kwargs) + response = self._client.completion(request) logger.debug( f"Received completion from model {self.model_name!r}", extra={ "model": self.model_name, "response": response, - "text": response.choices[0].message.content, + "text": response.message.content, "usage": self._usage_stats.model_dump(), }, ) return response - except Exception as e: - raise e finally: - if not skip_usage_tracking and response is not None: - self._track_token_usage_from_completion(response) + if not skip_usage_tracking: + self._track_usage( + response.usage if response is not None else None, + is_request_successful=response is not None, + ) - def consolidate_kwargs(self, **kwargs) -> dict[str, Any]: - # 
Remove purpose from kwargs to avoid passing it to the model - kwargs.pop("purpose", None) - kwargs = {**self._model_config.inference_parameters.generate_kwargs, **kwargs} - if self.model_provider.extra_body: - kwargs["extra_body"] = {**kwargs.get("extra_body", {}), **self.model_provider.extra_body} - if self.model_provider.extra_headers: - kwargs["extra_headers"] = self.model_provider.extra_headers - return kwargs + async def acompletion( + self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any + ) -> ChatCompletionResponse: + message_payloads = [message.to_dict() for message in messages] + logger.debug( + f"Prompting model {self.model_name!r}...", + extra={"model": self.model_name, "messages": message_payloads}, + ) + response = None + kwargs = self.consolidate_kwargs(**kwargs) + try: + request = self._build_chat_completion_request(message_payloads, kwargs) + response = await self._client.acompletion(request) + logger.debug( + f"Received completion from model {self.model_name!r}", + extra={ + "model": self.model_name, + "response": response, + "text": response.message.content, + "usage": self._usage_stats.model_dump(), + }, + ) + return response + finally: + if not skip_usage_tracking: + self._track_usage( + response.usage if response is not None else None, + is_request_successful=response is not None, + ) + + # --- generate / agenerate --- @catch_llm_exceptions def generate( @@ -169,7 +197,7 @@ def generate( max_conversation_restarts: int = 0, skip_usage_tracking: bool = False, purpose: str | None = None, - **kwargs, + **kwargs: Any, ) -> tuple[Any, list[ChatMessage]]: """Generate a parsed output with correction steps. @@ -251,7 +279,7 @@ def generate( # Process any tool calls in the response (handles parallel tool calling) if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): tool_call_turns += 1 - total_tool_calls += mcp_facade.tool_call_count(completion_response) + total_tool_calls += mcp_facade.get_tool_call_count(completion_response) if tool_call_turns > mcp_facade.max_tool_call_turns: # Gracefully refuse tool calls when budget is exhausted @@ -266,8 +294,8 @@ def generate( continue # Back to top # No tool calls remaining to process - response = (completion_response.choices[0].message.content or "").strip() - reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None) + response = (completion_response.message.content or "").strip() + reasoning_trace = completion_response.message.reasoning_content messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) curr_num_correction_steps += 1 @@ -304,335 +332,6 @@ def generate( return output_obj, messages - @catch_llm_exceptions - def generate_text_embeddings( - self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs - ) -> list[list[float]]: - logger.debug( - f"Generating embeddings with model {self.model_name!r}...", - extra={ - "model": self.model_name, - "input_count": len(input_texts), - }, - ) - kwargs = self.consolidate_kwargs(**kwargs) - response = None - try: - response = self._router.embedding(model=self.model_name, input=input_texts, **kwargs) - logger.debug( - f"Received embeddings from model {self.model_name!r}", - extra={ - "model": self.model_name, - "embedding_count": len(response.data) if response.data else 0, - "usage": self._usage_stats.model_dump(), - }, - ) - if response.data and len(response.data) == len(input_texts): - return [data["embedding"] for data in response.data] - 
else: - raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") - except Exception as e: - raise e - finally: - if not skip_usage_tracking and response is not None: - self._track_token_usage_from_embedding(response) - - @catch_llm_exceptions - def generate_image( - self, - prompt: str, - multi_modal_context: list[dict[str, Any]] | None = None, - skip_usage_tracking: bool = False, - **kwargs, - ) -> list[str]: - """Generate image(s) and return base64-encoded data. - - Automatically detects the appropriate API based on model name: - - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) → image_generation API - - All other models → chat/completions API (default) - - Both paths return base64-encoded image data. If the API returns multiple images, - all are returned in the list. - - Args: - prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation. - Only used with autoregressive models via chat completions API. - skip_usage_tracking: Whether to skip usage tracking - **kwargs: Additional arguments to pass to the model (including n=number of images) - - Returns: - List of base64-encoded image strings (without data URI prefix) - - Raises: - ImageGenerationError: If image generation fails or returns invalid data - """ - logger.debug( - f"Generating image with model {self.model_name!r}...", - extra={"model": self.model_name, "prompt": prompt}, - ) - - # Auto-detect API type based on model name - if is_image_diffusion_model(self.model_name): - images = self._generate_image_diffusion(prompt, skip_usage_tracking, **kwargs) - else: - images = self._generate_image_chat_completion(prompt, multi_modal_context, skip_usage_tracking, **kwargs) - - # Track image usage - if not skip_usage_tracking and len(images) > 0: - self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) - - return images - - def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: - if tool_alias is None: - return None - if self._mcp_registry is None: - raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.") - - try: - return self._mcp_registry.get_mcp(tool_alias=tool_alias) - except ValueError as exc: - raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc - - def _generate_image_chat_completion( - self, - prompt: str, - multi_modal_context: list[dict[str, Any]] | None = None, - skip_usage_tracking: bool = False, - **kwargs, - ) -> list[str]: - """Generate image(s) using autoregressive model via chat completions API. 
- - Args: - prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation - skip_usage_tracking: Whether to skip usage tracking - **kwargs: Additional arguments to pass to the model - - Returns: - List of base64-encoded image strings - """ - messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) - - response = None - try: - response = self.completion( - messages=messages, - skip_usage_tracking=skip_usage_tracking, - **kwargs, - ) - - logger.debug( - f"Received image(s) from autoregressive model {self.model_name!r}", - extra={"model": self.model_name, "response": response}, - ) - - # Validate response structure - if not response.choices or len(response.choices) == 0: - raise ImageGenerationError("Image generation response missing choices") - - message = response.choices[0].message - images = [] - - # Extract base64 from images attribute (primary path) - if hasattr(message, "images") and message.images: - for image in message.images: - # Handle different response formats - if isinstance(image, dict) and "image_url" in image: - image_url = image["image_url"] - - if isinstance(image_url, dict) and "url" in image_url: - if (b64 := _try_extract_base64(image_url["url"])) is not None: - images.append(b64) - elif isinstance(image_url, str): - if (b64 := _try_extract_base64(image_url)) is not None: - images.append(b64) - # Fallback: treat as base64 string - elif isinstance(image, str): - if (b64 := _try_extract_base64(image)) is not None: - images.append(b64) - - # Fallback: check content field if it looks like image data - if not images: - content = message.content or "" - if content and (content.startswith("data:image/") or is_base64_image(content)): - if (b64 := _try_extract_base64(content)) is not None: - images.append(b64) - - if not images: - raise ImageGenerationError("No image data found in image generation response") - - return images - - except Exception: - raise - - def _generate_image_diffusion(self, prompt: str, skip_usage_tracking: bool = False, **kwargs) -> list[str]: - """Generate image(s) using diffusion model via image_generation API. - - Always returns base64. If the API returns URLs instead of inline base64, - the images are downloaded and converted automatically. 
- - Returns: - List of base64-encoded image strings - """ - kwargs = self.consolidate_kwargs(**kwargs) - - response = None - - try: - response = self._router.image_generation(prompt=prompt, model=self.model_name, **kwargs) - - logger.debug( - f"Received {len(response.data)} image(s) from diffusion model {self.model_name!r}", - extra={"model": self.model_name, "response": response}, - ) - - # Validate response - if not response.data or len(response.data) == 0: - raise ImageGenerationError("Image generation returned no data") - - images = [b64 for img in response.data if (b64 := _try_extract_base64(img)) is not None] - - if not images: - raise ImageGenerationError("No image data could be extracted from response") - - return images - - except Exception: - raise - finally: - if not skip_usage_tracking and response is not None: - self._track_token_usage_from_image_diffusion(response) - - def _get_litellm_deployment(self, model_config: ModelConfig) -> litellm.DeploymentTypedDict: - provider = self._model_provider_registry.get_provider(model_config.provider) - api_key = None - if provider.api_key: - api_key = self._secret_resolver.resolve(provider.api_key) - api_key = api_key or "not-used-but-required" - - litellm_params = lazy.litellm.LiteLLM_Params( - model=f"{provider.provider_type}/{model_config.model}", - api_base=provider.endpoint, - api_key=api_key, - max_parallel_requests=model_config.inference_parameters.max_parallel_requests, - ) - return { - "model_name": model_config.model, - "litellm_params": litellm_params.model_dump(), - } - - def _track_token_usage_from_completion(self, response: litellm.types.utils.ModelResponse | None) -> None: - if response is None: - self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) - return - if ( - response.usage is not None - and response.usage.prompt_tokens is not None - and response.usage.completion_tokens is not None - ): - self._usage_stats.extend( - token_usage=TokenUsageStats( - input_tokens=response.usage.prompt_tokens, - output_tokens=response.usage.completion_tokens, - ), - request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), - ) - - def _track_token_usage_from_embedding(self, response: litellm.types.utils.EmbeddingResponse | None) -> None: - if response is None: - self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) - return - if response.usage is not None and response.usage.prompt_tokens is not None: - self._usage_stats.extend( - token_usage=TokenUsageStats( - input_tokens=response.usage.prompt_tokens, - output_tokens=0, - ), - request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), - ) - - def _track_token_usage_from_image_diffusion(self, response: litellm.types.utils.ImageResponse | None) -> None: - """Track token usage from image_generation API response.""" - if response is None: - self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) - return - - if response.usage is not None and isinstance(response.usage, lazy.litellm.types.utils.ImageUsage): - self._usage_stats.extend( - token_usage=TokenUsageStats( - input_tokens=response.usage.input_tokens, - output_tokens=response.usage.output_tokens, - ), - request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), - ) - else: - # Successful response but no token usage data (some providers don't report it) - self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=1, failed_requests=0)) 
- - async def acompletion( - self, messages: list[ChatMessage], skip_usage_tracking: bool = False, **kwargs: Any - ) -> litellm.ModelResponse: - message_payloads = [message.to_dict() for message in messages] - logger.debug( - f"Prompting model {self.model_name!r}...", - extra={"model": self.model_name, "messages": message_payloads}, - ) - response = None - kwargs = self.consolidate_kwargs(**kwargs) - try: - response = await self._router.acompletion(model=self.model_name, messages=message_payloads, **kwargs) - logger.debug( - f"Received completion from model {self.model_name!r}", - extra={ - "model": self.model_name, - "response": response, - "text": response.choices[0].message.content, - "usage": self._usage_stats.model_dump(), - }, - ) - return response - except Exception as e: - raise e - finally: - if not skip_usage_tracking and response is not None: - self._track_token_usage_from_completion(response) - - @acatch_llm_exceptions - async def agenerate_text_embeddings( - self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any - ) -> list[list[float]]: - logger.debug( - f"Generating embeddings with model {self.model_name!r}...", - extra={ - "model": self.model_name, - "input_count": len(input_texts), - }, - ) - kwargs = self.consolidate_kwargs(**kwargs) - response = None - try: - response = await self._router.aembedding(model=self.model_name, input=input_texts, **kwargs) - logger.debug( - f"Received embeddings from model {self.model_name!r}", - extra={ - "model": self.model_name, - "embedding_count": len(response.data) if response.data else 0, - "usage": self._usage_stats.model_dump(), - }, - ) - if response.data and len(response.data) == len(input_texts): - return [data["embedding"] for data in response.data] - else: - raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.data)}") - except Exception as e: - raise e - finally: - if not skip_usage_tracking and response is not None: - self._track_token_usage_from_embedding(response) - @acatch_llm_exceptions async def agenerate( self, @@ -679,7 +378,7 @@ async def agenerate( if mcp_facade is not None and mcp_facade.has_tool_calls(completion_response): tool_call_turns += 1 - total_tool_calls += mcp_facade.tool_call_count(completion_response) + total_tool_calls += mcp_facade.get_tool_call_count(completion_response) if tool_call_turns > mcp_facade.max_tool_call_turns: messages.extend(mcp_facade.refuse_completion_response(completion_response)) @@ -693,8 +392,8 @@ async def agenerate( continue - response = (completion_response.choices[0].message.content or "").strip() - reasoning_trace = getattr(completion_response.choices[0].message, "reasoning_content", None) + response = (completion_response.message.content or "").strip() + reasoning_trace = completion_response.message.reasoning_content messages.append(ChatMessage.as_assistant(content=response, reasoning_content=reasoning_trace or None)) curr_num_correction_steps += 1 @@ -730,19 +429,91 @@ async def agenerate( return output_obj, messages + # --- generate_text_embeddings / agenerate_text_embeddings --- + + @catch_llm_exceptions + def generate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": len(input_texts), + }, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response: EmbeddingResponse | None = None + try: + request = 
self._build_embedding_request(input_texts, kwargs) + response = self._client.embeddings(request) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.vectors), + "usage": self._usage_stats.model_dump(), + }, + ) + if len(response.vectors) == len(input_texts): + return response.vectors + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.vectors)}") + finally: + if not skip_usage_tracking: + self._track_usage( + response.usage if response is not None else None, + is_request_successful=response is not None, + ) + @acatch_llm_exceptions - async def agenerate_image( + async def agenerate_text_embeddings( + self, input_texts: list[str], skip_usage_tracking: bool = False, **kwargs: Any + ) -> list[list[float]]: + logger.debug( + f"Generating embeddings with model {self.model_name!r}...", + extra={ + "model": self.model_name, + "input_count": len(input_texts), + }, + ) + kwargs = self.consolidate_kwargs(**kwargs) + response: EmbeddingResponse | None = None + try: + request = self._build_embedding_request(input_texts, kwargs) + response = await self._client.aembeddings(request) + logger.debug( + f"Received embeddings from model {self.model_name!r}", + extra={ + "model": self.model_name, + "embedding_count": len(response.vectors), + "usage": self._usage_stats.model_dump(), + }, + ) + if len(response.vectors) == len(input_texts): + return response.vectors + raise ValueError(f"Expected {len(input_texts)} embeddings, but received {len(response.vectors)}") + finally: + if not skip_usage_tracking: + self._track_usage( + response.usage if response is not None else None, + is_request_successful=response is not None, + ) + + # --- generate_image / agenerate_image --- + + @catch_llm_exceptions + def generate_image( self, prompt: str, multi_modal_context: list[dict[str, Any]] | None = None, skip_usage_tracking: bool = False, **kwargs: Any, ) -> list[str]: - """Async version of generate_image. Generate image(s) and return base64-encoded data. + """Generate image(s) and return base64-encoded data. Automatically detects the appropriate API based on model name: - - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) → image_generation API - - All other models → chat/completions API (default) + - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) -> image_generation API + - All other models -> chat/completions API (default) Both paths return base64-encoded image data. If the API returns multiple images, all are returned in the list. 
@@ -765,133 +536,189 @@ async def agenerate_image( extra={"model": self.model_name, "prompt": prompt}, ) - # Auto-detect API type based on model name - if is_image_diffusion_model(self.model_name): - images = await self._agenerate_image_diffusion(prompt, skip_usage_tracking, **kwargs) - else: - images = await self._agenerate_image_chat_completion( - prompt, multi_modal_context, skip_usage_tracking, **kwargs - ) + kwargs = self.consolidate_kwargs(**kwargs) + response: ImageGenerationResponse | None = None + try: + request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) + response = self._client.generate_image(request) + + images = [img.b64_data for img in response.images] - # Track image usage - if not skip_usage_tracking and len(images) > 0: - self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) + if not images: + raise ImageGenerationError("No image data found in image generation response") - return images + if not skip_usage_tracking: + self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) - async def _agenerate_image_chat_completion( + return images + finally: + if not skip_usage_tracking: + self._track_usage( + response.usage if response is not None else None, + is_request_successful=response is not None, + ) + + @acatch_llm_exceptions + async def agenerate_image( self, prompt: str, multi_modal_context: list[dict[str, Any]] | None = None, skip_usage_tracking: bool = False, **kwargs: Any, ) -> list[str]: - """Async version of _generate_image_chat_completion. + """Async version of generate_image. Generate image(s) and return base64-encoded data. - Generate image(s) using autoregressive model via chat completions API. + Automatically detects the appropriate API based on model name: + - Diffusion models (DALL-E, Stable Diffusion, Imagen, etc.) -> image_generation API + - All other models -> chat/completions API (default) + + Both paths return base64-encoded image data. If the API returns multiple images, + all are returned in the list. Args: prompt: The prompt for image generation - multi_modal_context: Optional list of image contexts for multi-modal generation + multi_modal_context: Optional list of image contexts for multi-modal generation. + Only used with autoregressive models via chat completions API. 
skip_usage_tracking: Whether to skip usage tracking - **kwargs: Additional arguments to pass to the model + **kwargs: Additional arguments to pass to the model (including n=number of images) Returns: - List of base64-encoded image strings + List of base64-encoded image strings (without data URI prefix) + + Raises: + ImageGenerationError: If image generation fails or returns invalid data """ - messages = prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) + logger.debug( + f"Generating image with model {self.model_name!r}...", + extra={"model": self.model_name, "prompt": prompt}, + ) - response = None + kwargs = self.consolidate_kwargs(**kwargs) + response: ImageGenerationResponse | None = None try: - response = await self.acompletion( - messages=messages, - skip_usage_tracking=skip_usage_tracking, - **kwargs, - ) - - logger.debug( - f"Received image(s) from autoregressive model {self.model_name!r}", - extra={"model": self.model_name, "response": response}, - ) + request = self._build_image_generation_request(prompt, multi_modal_context, kwargs) + response = await self._client.agenerate_image(request) - # Validate response structure - if not response.choices or len(response.choices) == 0: - raise ImageGenerationError("Image generation response missing choices") - - message = response.choices[0].message - images = [] - - # Extract base64 from images attribute (primary path) - if hasattr(message, "images") and message.images: - for image in message.images: - # Handle different response formats - if isinstance(image, dict) and "image_url" in image: - image_url = image["image_url"] - - if isinstance(image_url, dict) and "url" in image_url: - if (b64 := _try_extract_base64(image_url["url"])) is not None: - images.append(b64) - elif isinstance(image_url, str): - if (b64 := _try_extract_base64(image_url)) is not None: - images.append(b64) - # Fallback: treat as base64 string - elif isinstance(image, str): - if (b64 := _try_extract_base64(image)) is not None: - images.append(b64) - - # Fallback: check content field if it looks like image data - if not images: - content = message.content or "" - if content and (content.startswith("data:image/") or is_base64_image(content)): - if (b64 := _try_extract_base64(content)) is not None: - images.append(b64) + images = [img.b64_data for img in response.images] if not images: raise ImageGenerationError("No image data found in image generation response") + if not skip_usage_tracking: + self._usage_stats.extend(image_usage=ImageUsageStats(total_images=len(images))) + return images + finally: + if not skip_usage_tracking: + self._track_usage( + response.usage if response is not None else None, + is_request_successful=response is not None, + ) - except Exception: - raise + # --- close / aclose --- - async def _agenerate_image_diffusion( - self, prompt: str, skip_usage_tracking: bool = False, **kwargs: Any - ) -> list[str]: - """Async version of _generate_image_diffusion. + def close(self) -> None: + """Release resources held by the underlying client.""" + self._client.close() - Generate image(s) using diffusion model via image_generation API. + async def aclose(self) -> None: + """Async release resources held by the underlying client.""" + await self._client.aclose() - Always returns base64. If the API returns URLs instead of inline base64, - the images are downloaded and converted automatically. 
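Before the private helpers, a rough sketch of the canonical request/response types the rewritten facade trades in. The field names are inferred from call sites in this diff (the _build_*_request helpers below, response.vectors, response.images[i].b64_data, message.tool_calls, usage.input_tokens, and so on); the real definitions live in data_designer/engine/models/clients/types.py and may differ in detail:

from typing import Any

from pydantic import BaseModel, Field


class ChatCompletionRequest(BaseModel):
    model: str
    messages: list[dict[str, Any]]
    temperature: float | None = None
    top_p: float | None = None
    max_tokens: int | None = None
    tools: list[dict[str, Any]] | None = None
    metadata: dict[str, Any] | None = None
    extra_body: dict[str, Any] | None = None
    extra_headers: dict[str, str] | None = None


class EmbeddingRequest(BaseModel):
    model: str
    inputs: list[str]
    encoding_format: str | None = None
    dimensions: int | None = None
    timeout: float | None = None
    extra_body: dict[str, Any] | None = None
    extra_headers: dict[str, str] | None = None


class ImageGenerationRequest(BaseModel):
    model: str
    prompt: str
    messages: list[dict[str, Any]] | None = None  # set only on the chat-completion path
    n: int | None = None
    timeout: float | None = None
    extra_body: dict[str, Any] | None = None
    extra_headers: dict[str, str] | None = None


class Usage(BaseModel):
    input_tokens: int | None = None
    output_tokens: int | None = None
    total_tokens: int | None = None
    generated_images: int | None = None  # populated for image generation


class ToolCall(BaseModel):
    id: str
    name: str
    arguments_json: str  # tool arguments as a raw JSON string


class AssistantMessage(BaseModel):
    content: str | None = None
    reasoning_content: str | None = None
    tool_calls: list[ToolCall] = Field(default_factory=list)


class ChatCompletionResponse(BaseModel):
    message: AssistantMessage
    usage: Usage | None = None


class EmbeddingResponse(BaseModel):
    vectors: list[list[float]]
    usage: Usage | None = None


class ImagePayload(BaseModel):
    b64_data: str  # base64 data without any data-URI prefix


class ImageGenerationResponse(BaseModel):
    images: list[ImagePayload]
    usage: Usage | None = None

Typed requests centralize the extra_body/extra_headers handling that the new TransportKwargs tests pin down further below, instead of threading raw kwargs through each transport call.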
+ # --- private helpers --- - Returns: - List of base64-encoded image strings - """ - kwargs = self.consolidate_kwargs(**kwargs) - - response = None + def _get_mcp_facade(self, tool_alias: str | None) -> MCPFacade | None: + if tool_alias is None: + return None + if self._mcp_registry is None: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} specified but no MCPRegistry configured.") try: - response = await self._router.aimage_generation(prompt=prompt, model=self.model_name, **kwargs) + return self._mcp_registry.get_mcp(tool_alias=tool_alias) + except ValueError as exc: + raise MCPConfigurationError(f"Tool alias {tool_alias!r} is not registered.") from exc + def _build_chat_completion_request( + self, messages: list[dict[str, Any]], kwargs: dict[str, Any] + ) -> ChatCompletionRequest: + """Build a ChatCompletionRequest from message payloads and consolidated kwargs.""" + request_fields: dict[str, Any] = {"model": self.model_name, "messages": messages} + metadata: dict[str, Any] = {} + + for key, value in kwargs.items(): + if key in _COMPLETION_REQUEST_FIELDS: + request_fields[key] = value + else: + metadata[key] = value + + if metadata: logger.debug( - f"Received {len(response.data)} image(s) from diffusion model {self.model_name!r}", - extra={"model": self.model_name, "response": response}, + "Unknown kwargs %s routed to LiteLLM metadata (not forwarded as model parameters). " + "Use 'extra_body' to pass non-standard parameters to the model.", + sorted(metadata.keys()), ) + request_fields["metadata"] = metadata + + return ChatCompletionRequest(**request_fields) + + def _build_embedding_request(self, input_texts: list[str], kwargs: dict[str, Any]) -> EmbeddingRequest: + """Build an EmbeddingRequest from input texts and consolidated kwargs.""" + return EmbeddingRequest( + model=self.model_name, + inputs=input_texts, + encoding_format=kwargs.get("encoding_format"), + dimensions=kwargs.get("dimensions"), + timeout=kwargs.get("timeout"), + extra_body=kwargs.get("extra_body"), + extra_headers=kwargs.get("extra_headers"), + ) - # Validate response - if not response.data or len(response.data) == 0: - raise ImageGenerationError("Image generation returned no data") + def _build_image_generation_request( + self, + prompt: str, + multi_modal_context: list[dict[str, Any]] | None, + kwargs: dict[str, Any], + ) -> ImageGenerationRequest: + """Build an ImageGenerationRequest, choosing chat-completion vs diffusion path.""" + is_diffusion = is_image_diffusion_model(self.model_name) + + if is_diffusion: + return ImageGenerationRequest( + model=self.model_name, + prompt=prompt, + n=kwargs.get("n"), + timeout=kwargs.get("timeout"), + extra_body=kwargs.get("extra_body"), + extra_headers=kwargs.get("extra_headers"), + ) - images = [b64 for img in response.data if (b64 := _try_extract_base64(img)) is not None] + chat_messages = [ + m.to_dict() for m in prompt_to_messages(user_prompt=prompt, multi_modal_context=multi_modal_context) + ] + return ImageGenerationRequest( + model=self.model_name, + prompt=prompt, + messages=chat_messages, + n=kwargs.get("n"), + timeout=kwargs.get("timeout"), + extra_body=kwargs.get("extra_body"), + extra_headers=kwargs.get("extra_headers"), + ) - if not images: - raise ImageGenerationError("No image data could be extracted from response") + def _track_usage(self, usage: Usage | None, *, is_request_successful: bool) -> None: + """Unified usage tracking from canonical Usage type.""" + if not is_request_successful: + 
self._usage_stats.extend(request_usage=RequestUsageStats(successful_requests=0, failed_requests=1)) + return - return images + token_usage = None + if usage is not None and usage.input_tokens is not None: + token_usage = TokenUsageStats( + input_tokens=usage.input_tokens, + output_tokens=usage.output_tokens or 0, + ) - except Exception: - raise - finally: - if not skip_usage_tracking and response is not None: - self._track_token_usage_from_image_diffusion(response) + self._usage_stats.extend( + token_usage=token_usage, + request_usage=RequestUsageStats(successful_requests=1, failed_requests=0), + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/factory.py index fb3b2e1d4..a23c0dbcc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/factory.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/factory.py @@ -37,17 +37,23 @@ def create_model_registry( Returns: A configured ModelRegistry instance. """ + from data_designer.engine.models.clients.factory import create_model_client from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.litellm_overrides import apply_litellm_patches from data_designer.engine.models.registry import ModelRegistry apply_litellm_patches() - def model_facade_factory(model_config, secret_resolver, model_provider_registry): + def model_facade_factory( + model_config: ModelConfig, + secret_resolver: SecretResolver, + model_provider_registry: ModelProviderRegistry, + ) -> ModelFacade: + client = create_model_client(model_config, secret_resolver, model_provider_registry) return ModelFacade( model_config, - secret_resolver, model_provider_registry, + client=client, mcp_registry=mcp_registry, ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py index 0b103e76b..b4dff0301 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py @@ -5,7 +5,7 @@ import logging from collections.abc import Callable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from data_designer.config.models import GenerationType, ModelConfig from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry @@ -27,7 +27,7 @@ def __init__( model_provider_registry: ModelProviderRegistry, model_configs: list[ModelConfig] | None = None, model_facade_factory: Callable[[ModelConfig, SecretResolver, ModelProviderRegistry], ModelFacade] | None = None, - ): + ) -> None: self._secret_resolver = secret_resolver self._model_provider_registry = model_provider_registry self._model_facade_factory = model_facade_factory @@ -69,7 +69,7 @@ def get_model_config(self, *, model_alias: str) -> ModelConfig: raise ValueError(f"No model config with alias {model_alias!r} found!") return self._model_configs[model_alias] - def get_model_usage_stats(self, total_time_elapsed: float) -> dict[str, dict]: + def get_model_usage_stats(self, total_time_elapsed: float) -> dict[str, dict[str, Any]]: return { model.model_name: model.usage_stats.get_usage_stats(total_time_elapsed=total_time_elapsed) for model in self._models.values() @@ -200,10 +200,18 @@ def run_health_check(self, model_aliases: list[str]) -> None: logger.error(f"{LOG_INDENT}❌ Failed!") raise e - def _set_model_configs(self, model_configs: 
list[ModelConfig]) -> None: - model_configs = model_configs or [] - self._model_configs = {mc.alias: mc for mc in model_configs} - # Models are now lazily initialized in get_model() when first requested + def _set_model_configs(self, model_configs: list[ModelConfig] | None) -> None: + self._model_configs = {mc.alias: mc for mc in (model_configs or [])} + + def close(self) -> None: + """Release resources held by all model facades.""" + for facade in self._models.values(): + facade.close() + + async def aclose(self) -> None: + """Async release resources held by all model facades.""" + for facade in self._models.values(): + await facade.aclose() def _get_model(self, model_config: ModelConfig) -> ModelFacade: if self._model_facade_factory is None: diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py b/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py index 3d01db6ac..a3380a140 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/testing/__init__.py @@ -10,6 +10,7 @@ StubMCPRegistry, StubMessage, StubResponse, + make_stub_completion_response, ) from data_designer.engine.testing.utils import assert_valid_plugin @@ -21,4 +22,5 @@ "StubMessage", "StubResponse", assert_valid_plugin.__name__, + make_stub_completion_response.__name__, ] diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py b/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py index af7d3ebfc..e47b2ffbc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py +++ b/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py @@ -16,8 +16,9 @@ from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, ToolConfig from data_designer.engine.mcp.facade import MCPFacade from data_designer.engine.model_provider import MCPProviderRegistry +from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, ToolCall from data_designer.engine.secret_resolver import SecretResolver -from data_designer.engine.testing.stubs import StubHuggingFaceSeedReader, StubMessage, StubResponse +from data_designer.engine.testing.stubs import StubHuggingFaceSeedReader # ============================================================================= # Seed reader fixtures @@ -151,61 +152,66 @@ def factory( # ============================================================================= -# Completion response fixtures +# Completion response fixtures (canonical ChatCompletionResponse) # ============================================================================= @pytest.fixture -def mock_completion_response_no_tools() -> StubResponse: +def mock_completion_response_no_tools() -> ChatCompletionResponse: """Mock LLM response with no tool calls.""" - return StubResponse(StubMessage(content="Hello, I can help with that.")) + return ChatCompletionResponse( + message=AssistantMessage(content="Hello, I can help with that."), + ) @pytest.fixture -def mock_completion_response_single_tool() -> StubResponse: +def mock_completion_response_single_tool() -> ChatCompletionResponse: """Mock LLM response with single tool call.""" - tool_call = { - "id": "call-1", - "type": "function", - "function": {"name": "lookup", "arguments": '{"query": "test"}'}, - } - return StubResponse(StubMessage(content="Let me look that up.", tool_calls=[tool_call])) + return ChatCompletionResponse( + 
message=AssistantMessage( + content="Let me look that up.", + tool_calls=[ + ToolCall(id="call-1", name="lookup", arguments_json='{"query": "test"}'), + ], + ), + ) @pytest.fixture -def mock_completion_response_parallel_tools() -> StubResponse: +def mock_completion_response_parallel_tools() -> ChatCompletionResponse: """Mock LLM response with multiple parallel tool calls.""" - tool_calls = [ - {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "first"}'}}, - {"id": "call-2", "type": "function", "function": {"name": "search", "arguments": '{"term": "second"}'}}, - {"id": "call-3", "type": "function", "function": {"name": "fetch", "arguments": '{"url": "example.com"}'}}, - ] - return StubResponse(StubMessage(content="Executing multiple tools.", tool_calls=tool_calls)) + return ChatCompletionResponse( + message=AssistantMessage( + content="Executing multiple tools.", + tool_calls=[ + ToolCall(id="call-1", name="lookup", arguments_json='{"query": "first"}'), + ToolCall(id="call-2", name="search", arguments_json='{"term": "second"}'), + ToolCall(id="call-3", name="fetch", arguments_json='{"url": "example.com"}'), + ], + ), + ) @pytest.fixture -def mock_completion_response_with_reasoning() -> StubResponse: +def mock_completion_response_with_reasoning() -> ChatCompletionResponse: """Mock LLM response with reasoning_content.""" - return StubResponse( - StubMessage( + return ChatCompletionResponse( + message=AssistantMessage( content=" Final answer with extra spaces. ", reasoning_content=" Thinking about the problem... ", - ) + ), ) @pytest.fixture -def mock_completion_response_tool_with_reasoning() -> StubResponse: +def mock_completion_response_tool_with_reasoning() -> ChatCompletionResponse: """Mock LLM response with tool calls and reasoning_content.""" - tool_call = { - "id": "call-1", - "type": "function", - "function": {"name": "lookup", "arguments": '{"query": "test"}'}, - } - return StubResponse( - StubMessage( + return ChatCompletionResponse( + message=AssistantMessage( content=" Looking it up... ", - tool_calls=[tool_call], + tool_calls=[ + ToolCall(id="call-1", name="lookup", arguments_json='{"query": "test"}'), + ], reasoning_content=" I should use the lookup tool. 
", - ) + ), ) diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py b/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py index 60567b51e..ce4c2fb5a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py +++ b/packages/data-designer-engine/src/data_designer/engine/testing/stubs.py @@ -8,6 +8,7 @@ from data_designer.config.base import ConfigBase, SingleColumnConfig from data_designer.engine.column_generators.generators.base import ColumnGeneratorCellByCell +from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, ToolCall from data_designer.engine.models.utils import ChatMessage from data_designer.engine.resources.seed_reader import SeedReader from data_designer.plugins.plugin import Plugin, PluginType @@ -24,7 +25,7 @@ def get_column_names(self) -> list[str]: def get_dataset_uri(self) -> str: return "unused in these tests" - def create_duckdb_connection(self): + def create_duckdb_connection(self) -> None: pass def get_seed_type(self) -> str: @@ -41,7 +42,7 @@ class ValidTestConfig(SingleColumnConfig): class ValidTestTask(ColumnGeneratorCellByCell[ValidTestConfig]): """Valid task for testing plugin creation.""" - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data @@ -70,12 +71,12 @@ class StubPluginConfigB(SingleColumnConfig): class StubPluginTaskA(ColumnGeneratorCellByCell[StubPluginConfigA]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data class StubPluginTaskB(ColumnGeneratorCellByCell[StubPluginConfigB]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data @@ -95,17 +96,17 @@ class StubPluginConfigBlobsAndSeeds(SingleColumnConfig): class StubPluginTaskModels(ColumnGeneratorCellByCell[StubPluginConfigModels]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data class StubPluginTaskModelsAndBlobs(ColumnGeneratorCellByCell[StubPluginConfigModelsAndBlobs]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data class StubPluginTaskBlobsAndSeeds(ColumnGeneratorCellByCell[StubPluginConfigBlobsAndSeeds]): - def generate(self, data: dict) -> dict: + def generate(self, data: dict[str, Any]) -> dict[str, Any]: return data @@ -135,7 +136,7 @@ def generate(self, data: dict) -> dict: # ============================================================================= -# Stub LLM response classes for testing +# Stub LLM response classes for testing (legacy, kept for backward compat) # ============================================================================= @@ -173,6 +174,26 @@ def __init__(self, message: StubMessage) -> None: self.choices = [StubChoice(message)] +# ============================================================================= +# Canonical stub helpers +# ============================================================================= + + +def make_stub_completion_response( + content: str | None = None, + reasoning_content: str | None = None, + tool_calls: list[ToolCall] | None = None, +) -> ChatCompletionResponse: + """Factory helper for creating canonical ChatCompletionResponse test objects.""" + return ChatCompletionResponse( + message=AssistantMessage( + content=content, + reasoning_content=reasoning_content, + tool_calls=tool_calls or [], + ), + ) + + 
# ============================================================================= # Stub MCP classes for testing tool calling # ============================================================================= @@ -195,8 +216,8 @@ def __init__( self, max_tool_call_turns: int = 3, tool_schemas: list[dict[str, Any]] | None = None, - process_fn: Callable[[Any], list[ChatMessage]] | None = None, - refuse_fn: Callable[[Any], list[ChatMessage]] | None = None, + process_fn: Callable[[ChatCompletionResponse], list[ChatMessage]] | None = None, + refuse_fn: Callable[[ChatCompletionResponse], list[ChatMessage]] | None = None, ) -> None: self.tool_alias = "tools" self.providers = ["tools"] @@ -208,34 +229,49 @@ def __init__( def get_tool_schemas(self) -> list[dict[str, Any]]: return self._tool_schemas - def tool_call_count(self, completion_response: Any) -> int: - tool_calls = getattr(completion_response.choices[0].message, "tool_calls", None) - return len(tool_calls) if tool_calls else 0 + def get_tool_call_count(self, completion_response: ChatCompletionResponse) -> int: + return len(completion_response.message.tool_calls) - def has_tool_calls(self, completion_response: Any) -> bool: - return completion_response.choices[0].message.tool_calls is not None + def has_tool_calls(self, completion_response: ChatCompletionResponse) -> bool: + return len(completion_response.message.tool_calls) > 0 - def process_completion_response(self, completion_response: Any) -> list[ChatMessage]: + def process_completion_response(self, completion_response: ChatCompletionResponse) -> list[ChatMessage]: if self._process_fn: return self._process_fn(completion_response) - message = completion_response.choices[0].message - tool_calls = message.tool_calls or [] + message = completion_response.message + tool_calls = message.tool_calls + tool_call_dicts = [ + { + "id": tc.id, + "type": "function", + "function": {"name": tc.name, "arguments": tc.arguments_json}, + } + for tc in tool_calls + ] return [ - ChatMessage.as_assistant(content=message.content or "", tool_calls=tool_calls), - *[ChatMessage.as_tool(content="tool-result", tool_call_id=tc["id"]) for tc in tool_calls], + ChatMessage.as_assistant(content=message.content or "", tool_calls=tool_call_dicts), + *[ChatMessage.as_tool(content="tool-result", tool_call_id=tc.id) for tc in tool_calls], ] - def refuse_completion_response(self, completion_response: Any) -> list[ChatMessage]: + def refuse_completion_response(self, completion_response: ChatCompletionResponse) -> list[ChatMessage]: if self._refuse_fn: return self._refuse_fn(completion_response) - message = completion_response.choices[0].message - tool_calls = message.tool_calls or [] + message = completion_response.message + tool_calls = message.tool_calls + tool_call_dicts = [ + { + "id": tc.id, + "type": "function", + "function": {"name": tc.name, "arguments": tc.arguments_json}, + } + for tc in tool_calls + ] return [ - ChatMessage.as_assistant(content="", tool_calls=tool_calls), + ChatMessage.as_assistant(content="", tool_calls=tool_call_dicts), *[ ChatMessage.as_tool( content="Tool call refused: maximum tool-calling turns reached.", - tool_call_id=tc["id"], + tool_call_id=tc.id, ) for tc in tool_calls ], @@ -243,14 +279,7 @@ def refuse_completion_response(self, completion_response: Any) -> list[ChatMessa class StubMCPRegistry: - """Stub MCP registry that returns a configurable StubMCPFacade. - - This stub provides a simple registry implementation for testing that - returns the configured StubMCPFacade instance. 
- - Args: - facade: The StubMCPFacade instance to return. If None, creates a default one. - """ + """Stub MCP registry that returns a configurable StubMCPFacade.""" def __init__(self, facade: StubMCPFacade | None = None) -> None: self._facade = facade or StubMCPFacade() diff --git a/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py b/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py index 610fa53f9..1300ed550 100644 --- a/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py +++ b/packages/data-designer-engine/tests/engine/mcp/test_mcp_facade.py @@ -10,65 +10,52 @@ from data_designer.config.mcp import LocalStdioMCPProvider, ToolConfig from data_designer.engine.mcp import io as mcp_io -from data_designer.engine.mcp.errors import DuplicateToolNameError, MCPToolError +from data_designer.engine.mcp.errors import DuplicateToolNameError, MCPConfigurationError, MCPToolError from data_designer.engine.mcp.facade import DEFAULT_TOOL_REFUSAL_MESSAGE, MCPFacade from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult from data_designer.engine.model_provider import MCPProviderRegistry - - -# Fake classes are used directly in tests to create custom responses -class FakeMessage: - """Fake message class for mocking LLM completion responses.""" - - def __init__( - self, - content: str | None, - tool_calls: list[dict] | None = None, - reasoning_content: str | None = None, - ) -> None: - self.content = content - self.tool_calls = tool_calls - self.reasoning_content = reasoning_content - - -class FakeChoice: - """Fake choice class for mocking LLM completion responses.""" - - def __init__(self, message: FakeMessage) -> None: - self.message = message - - -class FakeResponse: - """Fake response class for mocking LLM completion responses.""" - - def __init__(self, message: FakeMessage) -> None: - self.choices = [FakeChoice(message)] +from data_designer.engine.models.clients.types import AssistantMessage, ChatCompletionResponse, ToolCall + + +def _make_response( + content: str | None = None, + tool_calls: list[ToolCall] | None = None, + reasoning_content: str | None = None, +) -> ChatCompletionResponse: + """Shorthand for creating canonical test responses.""" + return ChatCompletionResponse( + message=AssistantMessage( + content=content, + reasoning_content=reasoning_content, + tool_calls=tool_calls or [], + ), + ) # ============================================================================= -# tool_call_count() tests +# get_tool_call_count() tests # ============================================================================= -def test_tool_call_count_no_tools(mock_completion_response_no_tools: FakeResponse) -> None: +def test_tool_call_count_no_tools(mock_completion_response_no_tools: ChatCompletionResponse) -> None: """Returns 0 when response has no tool calls.""" - assert MCPFacade.tool_call_count(mock_completion_response_no_tools) == 0 + assert MCPFacade.get_tool_call_count(mock_completion_response_no_tools) == 0 -def test_tool_call_count_single_tool(mock_completion_response_single_tool: FakeResponse) -> None: +def test_tool_call_count_single_tool(mock_completion_response_single_tool: ChatCompletionResponse) -> None: """Returns 1 for single tool call.""" - assert MCPFacade.tool_call_count(mock_completion_response_single_tool) == 1 + assert MCPFacade.get_tool_call_count(mock_completion_response_single_tool) == 1 -def test_tool_call_count_parallel_tools(mock_completion_response_parallel_tools: FakeResponse) -> None: +def 
test_tool_call_count_parallel_tools(mock_completion_response_parallel_tools: ChatCompletionResponse) -> None: """Returns correct count for parallel tool calls (e.g., 3).""" - assert MCPFacade.tool_call_count(mock_completion_response_parallel_tools) == 3 + assert MCPFacade.get_tool_call_count(mock_completion_response_parallel_tools) == 3 def test_tool_call_count_none_tool_calls_attribute() -> None: - """Returns 0 when tool_calls attribute is None.""" - response = FakeResponse(FakeMessage(content="Hello", tool_calls=None)) - assert MCPFacade.tool_call_count(response) == 0 + """Returns 0 when tool_calls is empty.""" + response = _make_response(content="Hello") + assert MCPFacade.get_tool_call_count(response) == 0 # ============================================================================= @@ -76,12 +63,12 @@ def test_tool_call_count_none_tool_calls_attribute() -> None: # ============================================================================= -def test_has_tool_calls_true(mock_completion_response_single_tool: FakeResponse) -> None: +def test_has_tool_calls_true(mock_completion_response_single_tool: ChatCompletionResponse) -> None: """Returns True when tool calls are present.""" assert MCPFacade.has_tool_calls(mock_completion_response_single_tool) is True -def test_has_tool_calls_false(mock_completion_response_no_tools: FakeResponse) -> None: +def test_has_tool_calls_false(mock_completion_response_no_tools: ChatCompletionResponse) -> None: """Returns False when no tool calls are present.""" assert MCPFacade.has_tool_calls(mock_completion_response_no_tools) is False @@ -93,7 +80,7 @@ def test_has_tool_calls_false(mock_completion_response_no_tools: FakeResponse) - def test_process_completion_no_tool_calls( stub_mcp_facade: MCPFacade, - mock_completion_response_no_tools: FakeResponse, + mock_completion_response_no_tools: ChatCompletionResponse, ) -> None: """Returns [assistant_message] when no tool calls present.""" messages = stub_mcp_facade.process_completion_response(mock_completion_response_no_tools) @@ -107,7 +94,7 @@ def test_process_completion_no_tool_calls( def test_process_completion_with_tool_calls( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: """Returns [assistant_msg, tool_msg] for tool calls.""" @@ -137,7 +124,7 @@ def mock_call_tools( def test_process_completion_preserves_content( stub_mcp_facade: MCPFacade, - mock_completion_response_no_tools: FakeResponse, + mock_completion_response_no_tools: ChatCompletionResponse, ) -> None: """Assistant content is preserved in returned message.""" messages = stub_mcp_facade.process_completion_response(mock_completion_response_no_tools) @@ -147,7 +134,7 @@ def test_process_completion_preserves_content( def test_process_completion_preserves_reasoning_content( stub_mcp_facade: MCPFacade, - mock_completion_response_with_reasoning: FakeResponse, + mock_completion_response_with_reasoning: ChatCompletionResponse, ) -> None: """Reasoning content is preserved when present.""" messages = stub_mcp_facade.process_completion_response(mock_completion_response_with_reasoning) @@ -158,7 +145,7 @@ def test_process_completion_preserves_reasoning_content( def test_process_completion_strips_whitespace_with_reasoning( stub_mcp_facade: MCPFacade, - mock_completion_response_with_reasoning: FakeResponse, + mock_completion_response_with_reasoning: ChatCompletionResponse, ) -> None: """Content and reasoning are stripped 
when reasoning is present.""" messages = stub_mcp_facade.process_completion_response(mock_completion_response_with_reasoning) @@ -170,7 +157,7 @@ def test_process_completion_strips_whitespace_with_reasoning( def test_process_completion_parallel_tool_calls( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, - mock_completion_response_parallel_tools: FakeResponse, + mock_completion_response_parallel_tools: ChatCompletionResponse, ) -> None: """All parallel tool calls are executed and messages returned.""" @@ -222,13 +209,10 @@ def test_process_completion_tool_not_in_allow_list( mcp_provider_registry=stub_mcp_provider_registry, ) - # Tool "forbidden" is not in allow_tools ["lookup", "search"] - tool_call = { - "id": "call-1", - "type": "function", - "function": {"name": "forbidden", "arguments": "{}"}, - } - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) + response = _make_response( + content="", + tool_calls=[ToolCall(id="call-1", name="forbidden", arguments_json="{}")], + ) with pytest.raises(MCPToolError, match="not permitted"): facade.process_completion_response(response) @@ -253,8 +237,10 @@ def mock_call_tools( monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} - response = FakeResponse(FakeMessage(content=None, tool_calls=[tool_call])) + response = _make_response( + content=None, + tool_calls=[ToolCall(id="call-1", name="lookup", arguments_json="{}")], + ) messages = stub_mcp_facade.process_completion_response(response) @@ -270,7 +256,7 @@ def mock_call_tools( def test_refuse_completion_no_tool_calls( stub_mcp_facade: MCPFacade, - mock_completion_response_no_tools: FakeResponse, + mock_completion_response_no_tools: ChatCompletionResponse, ) -> None: """Returns [assistant_message] when no tool calls to refuse.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_no_tools) @@ -282,7 +268,7 @@ def test_refuse_completion_no_tool_calls( def test_refuse_completion_single_tool( stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: """Returns assistant + refusal message for single tool call.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_single_tool) @@ -297,7 +283,7 @@ def test_refuse_completion_single_tool( def test_refuse_completion_parallel_tools( stub_mcp_facade: MCPFacade, - mock_completion_response_parallel_tools: FakeResponse, + mock_completion_response_parallel_tools: ChatCompletionResponse, ) -> None: """Returns assistant + refusal for each parallel tool call.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_parallel_tools) @@ -313,7 +299,7 @@ def test_refuse_completion_parallel_tools( def test_refuse_completion_default_message( stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: """Uses default refusal message when none provided.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_single_tool) @@ -323,7 +309,7 @@ def test_refuse_completion_default_message( def test_refuse_completion_custom_message( stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: 
"""Uses custom refusal message when provided.""" custom_message = "Custom refusal: Budget exceeded." @@ -337,12 +323,11 @@ def test_refuse_completion_custom_message( def test_refuse_completion_preserves_tool_call_ids( stub_mcp_facade: MCPFacade, - mock_completion_response_parallel_tools: FakeResponse, + mock_completion_response_parallel_tools: ChatCompletionResponse, ) -> None: """Refusal messages have correct tool_call_id linkage.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_parallel_tools) - # Verify each refusal message has the correct tool_call_id assert messages[1].tool_call_id == "call-1" assert messages[2].tool_call_id == "call-2" assert messages[3].tool_call_id == "call-3" @@ -350,7 +335,7 @@ def test_refuse_completion_preserves_tool_call_ids( def test_refuse_completion_preserves_reasoning( stub_mcp_facade: MCPFacade, - mock_completion_response_tool_with_reasoning: FakeResponse, + mock_completion_response_tool_with_reasoning: ChatCompletionResponse, ) -> None: """Reasoning content preserved in refusal scenario.""" messages = stub_mcp_facade.refuse_completion_response(mock_completion_response_tool_with_reasoning) @@ -363,7 +348,7 @@ def test_refuse_completion_preserves_reasoning( def test_refuse_does_not_call_mcp_server( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, - mock_completion_response_single_tool: FakeResponse, + mock_completion_response_single_tool: ChatCompletionResponse, ) -> None: """Verify MCP server is NOT called during refusal.""" call_tools_called = False @@ -484,40 +469,20 @@ def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MC monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - from data_designer.engine.mcp.errors import MCPConfigurationError - with pytest.raises(MCPConfigurationError, match="not found"): facade.get_tool_schemas() # ============================================================================= -# Tool call normalization via public API (process_completion_response) +# Tool call handling via public API (process_completion_response) # ============================================================================= -def test_process_completion_missing_tool_name(stub_mcp_facade: MCPFacade) -> None: - """process_completion_response raises MCPToolError when tool call has no name.""" - tool_call = {"id": "call-1", "function": {"arguments": "{}"}} # Missing name - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) - - with pytest.raises(MCPToolError, match="missing a tool name"): - stub_mcp_facade.process_completion_response(response) - - -def test_process_completion_invalid_json_arguments(stub_mcp_facade: MCPFacade) -> None: - """process_completion_response raises MCPToolError when arguments are invalid JSON.""" - tool_call = {"id": "call-1", "function": {"name": "lookup", "arguments": "not valid json"}} - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) - - with pytest.raises(MCPToolError, match="Invalid tool arguments"): - stub_mcp_facade.process_completion_response(response) - - -def test_process_completion_dict_arguments( +def test_process_completion_with_empty_arguments( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, ) -> None: - """process_completion_response handles dict arguments correctly.""" + """process_completion_response handles empty arguments gracefully.""" def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: return (MCPToolDefinition(name="lookup", 
description="Lookup", input_schema={"type": "object"}),) @@ -536,21 +501,22 @@ def mock_call_tools( monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - # Pass dict arguments (not JSON string) - tool_call = {"id": "call-1", "function": {"name": "lookup", "arguments": {"query": "test"}}} - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) + response = _make_response( + content="", + tool_calls=[ToolCall(id="call-1", name="lookup", arguments_json="{}")], + ) messages = stub_mcp_facade.process_completion_response(response) assert len(messages) == 2 - assert captured_args[0] == {"query": "test"} + assert captured_args[0] == {} -def test_process_completion_empty_arguments( +def test_process_completion_with_dict_arguments( monkeypatch: pytest.MonkeyPatch, stub_mcp_facade: MCPFacade, ) -> None: - """process_completion_response handles None/empty arguments gracefully.""" + """process_completion_response handles arguments via canonical ToolCall correctly.""" def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) @@ -569,85 +535,15 @@ def mock_call_tools( monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - tool_call = {"id": "call-1", "function": {"name": "lookup", "arguments": None}} - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) - - messages = stub_mcp_facade.process_completion_response(response) - - assert len(messages) == 2 - assert captured_args[0] == {} # Empty dict for None arguments - - -def test_process_completion_generates_tool_call_id( - monkeypatch: pytest.MonkeyPatch, - stub_mcp_facade: MCPFacade, -) -> None: - """process_completion_response generates UUID for tool calls without ID.""" - - def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) - - def mock_call_tools( - calls: list[tuple[Any, str, dict[str, Any]]], - *, - timeout_sec: float | None = None, - ) -> list[MCPToolResult]: - return [MCPToolResult(content="result") for _ in calls] - - monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - - # Tool call without id - tool_call = {"function": {"name": "lookup", "arguments": "{}"}} - response = FakeResponse(FakeMessage(content="", tool_calls=[tool_call])) - - messages = stub_mcp_facade.process_completion_response(response) - - # Should have generated an ID - assert len(messages) == 2 - assert messages[1].tool_call_id is not None - assert len(messages[1].tool_call_id) == 32 # UUID hex format - - -def test_process_completion_object_format_tool_calls( - monkeypatch: pytest.MonkeyPatch, - stub_mcp_facade: MCPFacade, -) -> None: - """process_completion_response handles object format tool calls.""" - - def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) - - captured_calls: list[tuple[str, dict[str, Any]]] = [] - - def mock_call_tools( - calls: list[tuple[Any, str, dict[str, Any]]], - *, - timeout_sec: float | None = None, - ) -> list[MCPToolResult]: - for _, tool_name, args in calls: - 
captured_calls.append((tool_name, args)) - return [MCPToolResult(content="result") for _ in calls] - - monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - monkeypatch.setattr(mcp_io, "call_tools", mock_call_tools) - - # Create object format tool call (simulating what some LLM libraries return) - class FakeFunction: - name = "lookup" - arguments = '{"query": "test"}' - - class FakeToolCall: - id = "call-obj-1" - function = FakeFunction() - - response = FakeResponse(FakeMessage(content="", tool_calls=[FakeToolCall()])) + response = _make_response( + content="", + tool_calls=[ToolCall(id="call-1", name="lookup", arguments_json='{"query": "test"}')], + ) messages = stub_mcp_facade.process_completion_response(response) assert len(messages) == 2 - assert captured_calls[0] == ("lookup", {"query": "test"}) - assert messages[1].tool_call_id == "call-obj-1" + assert captured_args[0] == {"query": "test"} # ============================================================================= @@ -771,7 +667,6 @@ def test_get_tool_schemas_duplicate_tool_names_raises_error( ) def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - # Both providers have a tool named "lookup" if provider.name == "tools": return ( MCPToolDefinition(name="lookup", description="Lookup from tools", input_schema={"type": "object"}), @@ -804,7 +699,6 @@ def test_get_tool_schemas_duplicate_tool_names_reports_all_duplicates( ) def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - # Both providers have "lookup" and "search" as duplicates if provider.name == "tools": return ( MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}), @@ -820,7 +714,6 @@ def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MC with pytest.raises(DuplicateToolNameError) as exc_info: facade.get_tool_schemas() - # Both duplicates should be reported assert "lookup" in str(exc_info.value) assert "search" in str(exc_info.value) @@ -841,14 +734,12 @@ def test_get_tool_schemas_no_duplicates_passes( ) def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MCPToolDefinition, ...]: - # Each provider has unique tool names if provider.name == "tools": return (MCPToolDefinition(name="lookup", description="Lookup", input_schema={"type": "object"}),) return (MCPToolDefinition(name="fetch", description="Fetch", input_schema={"type": "object"}),) monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - # Should not raise schemas = facade.get_tool_schemas() assert len(schemas) == 2 @@ -867,6 +758,5 @@ def mock_list_tools(provider: Any, timeout_sec: float | None = None) -> tuple[MC monkeypatch.setattr(mcp_io, "list_tools", mock_list_tools) - # Should not raise schemas = stub_mcp_facade.get_tool_schemas() assert len(schemas) == 2 diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py index 194884f13..62828039d 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_client_errors.py @@ -10,6 +10,7 @@ from data_designer.engine.models.clients.errors import ( ProviderError, ProviderErrorKind, + extract_message_from_exception_string, map_http_error_to_provider_error, map_http_status_to_provider_error_kind, ) @@ -207,3 +208,46 @@ def 
test_map_http_error_retry_after_returns_none_for_garbage() -> None: ) error = map_http_error_to_provider_error(response=response, provider_name="stub-provider") assert error.retry_after is None + + +@pytest.mark.parametrize( + "raw,expected", + [ + ( + "Error code: 400 - {'error': {'message': 'Context length exceeded', 'type': 'invalid_request_error'}}".replace( + "'", '"' + ), + "Context length exceeded", + ), + ( + 'Error code: 403 - {"error": "Insufficient permissions"}', + "Insufficient permissions", + ), + ( + 'Error code: 500 - {"message": "Internal failure"}', + "Internal failure", + ), + ( + 'Error code: 422 - {"detail": "Unprocessable entity"}', + "Unprocessable entity", + ), + ( + "Connection timed out", + "Connection timed out", + ), + ( + "Error code: 400 - {not valid json", + "Error code: 400 - {not valid json", + ), + ], + ids=[ + "nested-error-message", + "top-level-error-string", + "top-level-message-string", + "top-level-detail-string", + "no-json-passthrough", + "malformed-json-passthrough", + ], +) +def test_extract_message_from_exception_string(raw: str, expected: str) -> None: + assert extract_message_from_exception_string(raw) == expected diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py index 7c9b8db9a..c95e7c070 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_litellm_bridge.py @@ -58,13 +58,13 @@ def test_completion_maps_canonical_fields_from_litellm_response( mock_router.completion.assert_called_once_with( model="stub-model", messages=[{"role": "user", "content": "hello"}], + extra_headers={"x-trace": "1"}, tools=[{"type": "function", "function": {"name": "lookup"}}], temperature=0.2, top_p=0.8, max_tokens=256, - extra_body={"foo": "bar"}, - extra_headers={"x-trace": "1"}, metadata={"trace_id": "abc"}, + foo="bar", ) @@ -84,6 +84,7 @@ async def test_acompletion_maps_canonical_fields_from_litellm_response( mock_router.acompletion.assert_awaited_once_with( model="stub-model", messages=[{"role": "user", "content": "hello"}], + extra_headers=None, ) @@ -108,6 +109,7 @@ def test_embeddings_maps_vectors_and_usage( mock_router.embedding.assert_called_once_with( model="stub-model", input=["a", "b"], + extra_headers=None, encoding_format="float", dimensions=32, ) @@ -148,6 +150,7 @@ def test_generate_image_uses_chat_completion_path_when_messages_provided( mock_router.completion.assert_called_once_with( model="stub-model", messages=messages, + extra_headers=None, n=1, ) mock_router.image_generation.assert_not_called() @@ -176,7 +179,9 @@ def test_generate_image_uses_diffusion_path_without_messages( assert result.usage.output_tokens == 12 assert result.usage.total_tokens == 21 assert result.usage.generated_images == 2 - mock_router.image_generation.assert_called_once_with(prompt="make an image", model="stub-model", n=2) + mock_router.image_generation.assert_called_once_with( + prompt="make an image", model="stub-model", extra_headers=None, n=2 + ) @pytest.mark.asyncio @@ -197,7 +202,7 @@ async def test_aembeddings_maps_vectors_and_usage( assert result.usage is not None assert result.usage.input_tokens == 5 assert result.raw is response - mock_router.aembedding.assert_awaited_once_with(model="stub-model", input=["x", "y"]) + mock_router.aembedding.assert_awaited_once_with(model="stub-model", input=["x", "y"], extra_headers=None) def 
test_completion_coerces_list_content_blocks_to_string( @@ -245,7 +250,9 @@ async def test_agenerate_image_uses_diffusion_path_without_messages( assert result.images[0].b64_data == "YXN5bmM=" assert result.usage is not None assert result.usage.generated_images == 1 - mock_router.aimage_generation.assert_awaited_once_with(prompt="async image", model="stub-model", n=1) + mock_router.aimage_generation.assert_awaited_once_with( + prompt="async image", model="stub-model", extra_headers=None, n=1 + ) def test_completion_with_empty_choices_returns_empty_message( @@ -300,7 +307,7 @@ def test_completion_wraps_router_exception_with_status_code( assert exc_info.value.kind == ProviderErrorKind.RATE_LIMIT assert exc_info.value.status_code == 429 assert exc_info.value.provider_name == "stub-provider" - assert exc_info.value.cause is exc + assert exc_info.value.__cause__ is exc def test_completion_wraps_generic_router_exception( diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py new file mode 100644 index 000000000..d48502431 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_parsing.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.engine.models.clients.parsing import extract_tool_calls +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + EmbeddingRequest, + ImageGenerationRequest, + TransportKwargs, +) + +# --- TransportKwargs.from_request: extra_body flattening --- + + +def test_extra_body_keys_are_flattened_into_body() -> None: + request = ChatCompletionRequest( + model="m", + messages=[], + temperature=0.7, + extra_body={"reasoning_effort": "high", "seed": 42}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.body["temperature"] == 0.7 + assert transport.body["reasoning_effort"] == "high" + assert transport.body["seed"] == 42 + assert "extra_body" not in transport.body + + +def test_extra_body_none_produces_no_extra_keys() -> None: + request = ChatCompletionRequest(model="m", messages=[], temperature=0.5) + transport = TransportKwargs.from_request(request) + + assert transport.body == {"temperature": 0.5} + assert "extra_body" not in transport.body + + +def test_extra_body_empty_dict_produces_no_extra_keys() -> None: + request = ChatCompletionRequest(model="m", messages=[], extra_body={}) + transport = TransportKwargs.from_request(request) + + assert "extra_body" not in transport.body + + +# --- TransportKwargs.from_request: extra_headers separation --- + + +def test_extra_headers_are_separated_into_headers() -> None: + request = ChatCompletionRequest( + model="m", + messages=[], + extra_headers={"X-Custom": "value", "Authorization": "Bearer tok"}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.headers == {"X-Custom": "value", "Authorization": "Bearer tok"} + assert "extra_headers" not in transport.body + + +def test_extra_headers_none_produces_empty_headers() -> None: + request = ChatCompletionRequest(model="m", messages=[]) + transport = TransportKwargs.from_request(request) + + assert transport.headers == {} + + +# --- TransportKwargs.from_request: combined --- + + +def test_extra_body_and_headers_together() -> None: + request = ChatCompletionRequest( + model="m", + messages=[], + 
temperature=0.9, + max_tokens=100, + extra_body={"seed": 1}, + extra_headers={"X-Req-Id": "abc"}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.body == {"temperature": 0.9, "max_tokens": 100, "seed": 1} + assert transport.headers == {"X-Req-Id": "abc"} + + +# --- TransportKwargs.from_request: exclude parameter --- + + +def test_exclude_removes_fields_from_body() -> None: + request = ImageGenerationRequest( + model="m", + prompt="draw a cat", + messages=[{"role": "user", "content": "hi"}], + n=2, + extra_body={"quality": "hd"}, + ) + transport = TransportKwargs.from_request(request, exclude=frozenset({"messages", "prompt"})) + + assert "messages" not in transport.body + assert "prompt" not in transport.body + assert transport.body["n"] == 2 + assert transport.body["quality"] == "hd" + + +# --- TransportKwargs.from_request: works with all request types --- + + +def test_embedding_request() -> None: + request = EmbeddingRequest( + model="m", + inputs=["hello"], + extra_body={"input_type": "query"}, + extra_headers={"X-Api-Version": "2"}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.body["input_type"] == "query" + assert transport.headers == {"X-Api-Version": "2"} + assert "extra_body" not in transport.body + assert "extra_headers" not in transport.body + + +def test_image_generation_request() -> None: + request = ImageGenerationRequest( + model="m", + prompt="sunset", + n=3, + extra_body={"size": "1024x1024"}, + ) + transport = TransportKwargs.from_request(request) + + assert transport.body["n"] == 3 + assert transport.body["size"] == "1024x1024" + assert transport.headers == {} + + +# --- TransportKwargs: falsy headers --- + + +def test_transport_kwargs_empty_headers_is_falsy() -> None: + tk = TransportKwargs(body={"a": 1}, headers={}) + assert not tk.headers + + +@pytest.mark.parametrize( + ("extra_body", "expected_body_keys"), + [ + (None, set()), + ({}, set()), + ({"a": 1}, {"a"}), + ({"a": 1, "b": 2}, {"a", "b"}), + ], +) +def test_extra_body_variations(extra_body: dict | None, expected_body_keys: set[str]) -> None: + request = ChatCompletionRequest(model="m", messages=[], extra_body=extra_body) + transport = TransportKwargs.from_request(request) + + assert expected_body_keys.issubset(transport.body.keys()) + assert "extra_body" not in transport.body + + +# --- extract_tool_calls --- + + +def _make_raw_tool_call( + tool_id: str | None = "call-1", + name: str = "lookup", + arguments: str = '{"q": "test"}', +) -> dict: + tc: dict = {"type": "function", "function": {"name": name, "arguments": arguments}} + if tool_id is not None: + tc["id"] = tool_id + return tc + + +def test_extract_tool_calls_basic() -> None: + raw = [_make_raw_tool_call()] + result = extract_tool_calls(raw) + + assert len(result) == 1 + assert result[0].id == "call-1" + assert result[0].name == "lookup" + assert result[0].arguments_json == '{"q": "test"}' + + +@pytest.mark.parametrize("tool_id", [None, ""], ids=["missing_id", "empty_string_id"]) +def test_extract_tool_calls_falsy_id_generates_uuid(tool_id: str | None) -> None: + raw = [_make_raw_tool_call(tool_id=tool_id)] + result = extract_tool_calls(raw) + + assert len(result) == 1 + assert len(result[0].id) == 32 # uuid4().hex length + assert result[0].id.isalnum() + + +def test_extract_tool_calls_multiple_missing_ids_are_unique() -> None: + raw = [_make_raw_tool_call(tool_id=None), _make_raw_tool_call(tool_id=None)] + result = extract_tool_calls(raw) + + assert result[0].id != result[1].id + + 
+@pytest.mark.parametrize("raw_input", [None, []], ids=["none", "empty_list"]) +def test_extract_tool_calls_empty_input(raw_input: list | None) -> None: + assert extract_tool_calls(raw_input) == [] + + +def test_extract_tool_calls_none_arguments() -> None: + raw = [{"id": "call-1", "function": {"name": "lookup", "arguments": None}}] + result = extract_tool_calls(raw) + + assert result[0].arguments_json == "{}" diff --git a/packages/data-designer-engine/tests/engine/models/conftest.py b/packages/data-designer-engine/tests/engine/models/conftest.py index 601dbb4ad..3f065217b 100644 --- a/packages/data-designer-engine/tests/engine/models/conftest.py +++ b/packages/data-designer-engine/tests/engine/models/conftest.py @@ -1,7 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + from pathlib import Path +from unittest.mock import MagicMock import pytest @@ -11,6 +14,7 @@ ModelConfig, ) from data_designer.engine.model_provider import ModelProvider, ModelProviderRegistry +from data_designer.engine.models.clients.base import ModelClient from data_designer.engine.models.factory import create_model_registry from data_designer.engine.models.registry import ModelRegistry from data_designer.engine.secret_resolver import SecretsFileResolver @@ -68,7 +72,11 @@ def stub_model_configs() -> list[ModelConfig]: @pytest.fixture -def stub_model_registry(stub_model_configs, stub_secrets_resolver, stub_model_provider_registry) -> ModelRegistry: +def stub_model_registry( + stub_model_configs: list[ModelConfig], + stub_secrets_resolver: SecretsFileResolver, + stub_model_provider_registry: ModelProviderRegistry, +) -> ModelRegistry: return create_model_registry( model_configs=stub_model_configs, secret_resolver=stub_secrets_resolver, @@ -76,6 +84,12 @@ def stub_model_registry(stub_model_configs, stub_secrets_resolver, stub_model_pr ) +@pytest.fixture +def stub_model_client() -> MagicMock: + """Mock ModelClient for testing ModelFacade without a real LiteLLM router.""" + return MagicMock(spec=ModelClient) + + @pytest.fixture def stub_mcp_facade_for_model() -> StubMCPFacade: """Default stub MCP facade with max_tool_call_turns=3.""" diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 662cbc762..ae92682a0 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -3,33 +3,41 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest -import data_designer.lazy_heavy_imports as lazy from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError +from data_designer.engine.models.clients.types import ( + ChatCompletionResponse, + EmbeddingResponse, + ImageGenerationResponse, + ImagePayload, + ToolCall, +) from data_designer.engine.models.errors import ImageGenerationError, ModelGenerationValidationFailureError -from data_designer.engine.models.facade import CustomRouter, ModelFacade +from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.parsers.errors import ParserException from data_designer.engine.models.utils import ChatMessage -from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, StubMessage, StubResponse - -if 
TYPE_CHECKING: - from litellm.types.utils import EmbeddingResponse, ModelResponse +from data_designer.engine.testing import StubMCPFacade, StubMCPRegistry, make_stub_completion_response -def mock_oai_response_object(response_text: str) -> StubResponse: - return StubResponse(StubMessage(content=response_text)) +def _make_response(content: str | None = None, **kwargs: Any) -> ChatCompletionResponse: + """Shorthand for creating a ChatCompletionResponse in tests.""" + return make_stub_completion_response(content=content, **kwargs) @pytest.fixture -def stub_model_facade(stub_model_configs, stub_secrets_resolver, stub_model_provider_registry): +def stub_model_facade( + stub_model_configs: list[Any], + stub_model_client: MagicMock, + stub_model_provider_registry: Any, +) -> ModelFacade: return ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, ) @@ -38,18 +46,6 @@ def stub_completion_messages() -> list[ChatMessage]: return [ChatMessage.as_user("test")] -@pytest.fixture -def stub_expected_completion_response(): - return lazy.litellm.types.utils.ModelResponse( - choices=lazy.litellm.types.utils.Choices(message=lazy.litellm.types.utils.Message(content="Test response")) - ) - - -@pytest.fixture -def stub_expected_embedding_response(): - return lazy.litellm.types.utils.EmbeddingResponse(data=[{"embedding": [0.1, 0.2, 0.3]}] * 2) - - @pytest.mark.parametrize( "max_correction_steps,max_conversation_restarts,total_calls", [ @@ -69,7 +65,7 @@ def test_generate( max_conversation_restarts: int, total_calls: int, ) -> None: - bad_response = mock_oai_response_object("bad response") + bad_response = _make_response("bad response") mock_completion.side_effect = lambda *args, **kwargs: bad_response def _failing_parser(response: str) -> str: @@ -110,14 +106,11 @@ def test_generate_with_system_prompt( system_prompt: str, expected_messages: list[ChatMessage], ) -> None: - # Capture messages at call time since they get mutated after the call captured_messages = [] - def capture_and_return(*args: Any, **kwargs: Any) -> ModelResponse: - captured_messages.append(list(args[1])) # Copy the messages list - return lazy.litellm.types.utils.ModelResponse( - choices=lazy.litellm.types.utils.Choices(message=lazy.litellm.types.utils.Message(content="Hello!")) - ) + def capture_and_return(*args: Any, **kwargs: Any) -> ChatCompletionResponse: + captured_messages.append(list(args[1])) + return _make_response("Hello!") mock_completion.side_effect = capture_and_return @@ -143,21 +136,21 @@ def test_generate_strips_response_content( expected: str, ) -> None: """Response content from the LLM is stripped of leading/trailing whitespace.""" - mock_completion.side_effect = lambda *args, **kwargs: StubResponse(StubMessage(content=raw_content)) + mock_completion.side_effect = lambda *args, **kwargs: _make_response(raw_content) result, _ = stub_model_facade.generate(prompt="test", parser=lambda x: x) assert result == expected -def test_model_alias_property(stub_model_facade, stub_model_configs): +def test_model_alias_property(stub_model_facade: ModelFacade, stub_model_configs: list[Any]) -> None: assert stub_model_facade.model_alias == stub_model_configs[0].alias -def test_usage_stats_property(stub_model_facade): +def test_usage_stats_property(stub_model_facade: ModelFacade) -> None: assert stub_model_facade.usage_stats is not None assert hasattr(stub_model_facade.usage_stats, "model_dump") -def 
test_consolidate_kwargs(stub_model_configs, stub_model_facade): +def test_consolidate_kwargs(stub_model_configs: list[Any], stub_model_facade: ModelFacade) -> None: # Model config generate kwargs are used as base, and purpose is removed result = stub_model_facade.consolidate_kwargs(purpose="test") assert result == stub_model_configs[0].inference_parameters.generate_kwargs @@ -191,126 +184,105 @@ def test_consolidate_kwargs(stub_model_configs, stub_model_facade): True, ], ) -@patch.object(CustomRouter, "completion", autospec=True) def test_completion_success( - mock_router_completion: Any, stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_model_client: MagicMock, skip_usage_tracking: bool, ) -> None: - mock_router_completion.side_effect = lambda self, model, messages, **kwargs: stub_expected_completion_response + expected_response = _make_response("Test response") + stub_model_client.completion.return_value = expected_response result = stub_model_facade.completion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking) - expected_messages = [message.to_dict() for message in stub_completion_messages] - assert result == stub_expected_completion_response - assert mock_router_completion.call_count == 1 - assert mock_router_completion.call_args[1] == { - "model": "stub-model-text", - "messages": expected_messages, - **stub_model_configs[0].inference_parameters.generate_kwargs, - } + assert result == expected_response + assert stub_model_client.completion.call_count == 1 -@patch.object(CustomRouter, "completion", autospec=True) def test_completion_with_exception( - mock_router_completion: Any, stub_completion_messages: list[ChatMessage], stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: - mock_router_completion.side_effect = Exception("Router error") + stub_model_client.completion.side_effect = Exception("Router error") with pytest.raises(Exception, match="Router error"): stub_model_facade.completion(stub_completion_messages) -@patch.object(CustomRouter, "completion", autospec=True) def test_completion_with_kwargs( - mock_router_completion: Any, stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_model_client: MagicMock, ) -> None: - captured_kwargs = {} - - def mock_completion(self: Any, model: str, messages: list[dict[str, Any]], **kwargs: Any) -> ModelResponse: - captured_kwargs.update(kwargs) - return stub_expected_completion_response - - mock_router_completion.side_effect = mock_completion + expected_response = _make_response("Test response") + stub_model_client.completion.return_value = expected_response kwargs = {"temperature": 0.7, "max_tokens": 100} result = stub_model_facade.completion(stub_completion_messages, **kwargs) - assert result == stub_expected_completion_response - # completion kwargs overrides model config generate kwargs - assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} + assert result == expected_response + assert stub_model_client.completion.call_count == 1 -@patch.object(CustomRouter, "embedding", autospec=True) def test_generate_text_embeddings_success( - mock_router_embedding: Any, stub_model_facade: ModelFacade, - stub_expected_embedding_response: EmbeddingResponse, + stub_model_client: MagicMock, ) -> None: - mock_router_embedding.side_effect = lambda self, model, 
input, **kwargs: stub_expected_embedding_response + expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + stub_model_client.embeddings.return_value = EmbeddingResponse(vectors=expected_vectors) input_texts = ["test1", "test2"] result = stub_model_facade.generate_text_embeddings(input_texts) - assert result == [data["embedding"] for data in stub_expected_embedding_response.data] + assert result == expected_vectors -@patch.object(CustomRouter, "embedding", autospec=True) -def test_generate_text_embeddings_with_exception(mock_router_embedding: Any, stub_model_facade: ModelFacade) -> None: - mock_router_embedding.side_effect = Exception("Router error") +def test_generate_text_embeddings_with_exception( + stub_model_facade: ModelFacade, + stub_model_client: MagicMock, +) -> None: + stub_model_client.embeddings.side_effect = Exception("Router error") with pytest.raises(Exception, match="Router error"): stub_model_facade.generate_text_embeddings(["test1", "test2"]) -@patch.object(CustomRouter, "embedding", autospec=True) def test_generate_text_embeddings_with_kwargs( - mock_router_embedding: Any, stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_embedding_response: EmbeddingResponse, + stub_model_client: MagicMock, ) -> None: - captured_kwargs = {} - - def mock_embedding(self: Any, model: str, input: list[str], **kwargs: Any) -> EmbeddingResponse: - captured_kwargs.update(kwargs) - return stub_expected_embedding_response - - mock_router_embedding.side_effect = mock_embedding + expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + stub_model_client.embeddings.return_value = EmbeddingResponse(vectors=expected_vectors) kwargs = {"temperature": 0.7, "max_tokens": 100, "input_type": "query"} _ = stub_model_facade.generate_text_embeddings(["test1", "test2"], **kwargs) - assert captured_kwargs == {**stub_model_configs[0].inference_parameters.generate_kwargs, **kwargs} + assert stub_model_client.embeddings.call_count == 1 def test_generate_with_mcp_tools( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: - tool_call = { - "id": "call-1", - "type": "function", - "function": {"name": "lookup", "arguments": '{"query": "foo"}'}, - } + tool_call = ToolCall(id="call-1", name="lookup", arguments_json='{"query": "foo"}') responses = [ - StubResponse(StubMessage(content=None, tool_calls=[tool_call])), - StubResponse(StubMessage(content="final result")), + _make_response(content=None, tool_calls=[tool_call]), + _make_response("final result"), ] captured_calls: list[tuple[list[ChatMessage], dict[str, Any]]] = [] registry_calls: list[tuple[str, str, dict[str, str], None]] = [] - def process_with_tracking(completion_response: Any) -> list[ChatMessage]: - message = completion_response.choices[0].message + def process_with_tracking(completion_response: ChatCompletionResponse) -> list[ChatMessage]: + message = completion_response.message if not message.tool_calls: return [ChatMessage.as_assistant(content=message.content or "")] registry_calls.append(("tools", "lookup", {"query": "foo"}, None)) + tc_dict = { + "id": "call-1", + "type": "function", + "function": {"name": "lookup", "arguments": '{"query": "foo"}'}, + } return [ - ChatMessage.as_assistant(content="", tool_calls=[tool_call]), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="tool-output", tool_call_id="call-1"), ] @@ -325,14 +297,14 @@ def process_with_tracking(completion_response: Any) -> list[ChatMessage]: ) 
registry = StubMCPRegistry(facade) - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: captured_calls.append((messages, kwargs)) return responses.pop(0) model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -348,12 +320,12 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_with_tools_missing_registry( - stub_model_configs: Any, stub_secrets_resolver: Any, stub_model_provider_registry: Any + stub_model_configs: Any, stub_model_client: MagicMock, stub_model_provider_registry: Any ) -> None: model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=None, ) @@ -368,32 +340,32 @@ def test_generate_with_tools_missing_registry( def test_generate_with_tool_alias_multiple_turns( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Multiple tool call turns before final response.""" - tool_call_1 = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "foo"}'}} - tool_call_2 = {"id": "call-2", "type": "function", "function": {"name": "search", "arguments": '{"term": "bar"}'}} + tool_call_1 = ToolCall(id="call-1", name="lookup", arguments_json='{"query": "foo"}') + tool_call_2 = ToolCall(id="call-2", name="search", arguments_json='{"term": "bar"}') responses = [ - StubResponse(StubMessage(content="First lookup", tool_calls=[tool_call_1])), - StubResponse(StubMessage(content="Second search", tool_calls=[tool_call_2])), - StubResponse(StubMessage(content="final result after two tool turns")), + _make_response("First lookup", tool_calls=[tool_call_1]), + _make_response("Second search", tool_calls=[tool_call_2]), + _make_response("final result after two tool turns"), ] call_count = 0 facade = StubMCPFacade(max_tool_call_turns=5) registry = StubMCPRegistry(facade) - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal call_count call_count += 1 return responses.pop(0) model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -406,29 +378,29 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_with_tools_tracks_usage_stats( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Tool usage stats are properly tracked with generations_with_tools incremented.""" - tool_call_1 = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"query": "foo"}'}} - tool_call_2 = {"id": "call-2", "type": "function", "function": {"name": "search", "arguments": '{"term": "bar"}'}} + tool_call_1 = ToolCall(id="call-1", name="lookup", arguments_json='{"query": "foo"}') + tool_call_2 = ToolCall(id="call-2", name="search", arguments_json='{"term": "bar"}') responses = [ - 
StubResponse(StubMessage(content="First lookup", tool_calls=[tool_call_1])), - StubResponse(StubMessage(content="Second search", tool_calls=[tool_call_2])), - StubResponse(StubMessage(content="final result")), + _make_response("First lookup", tool_calls=[tool_call_1]), + _make_response("Second search", tool_calls=[tool_call_2]), + _make_response("final result"), ] facade = StubMCPFacade(max_tool_call_turns=5) registry = StubMCPRegistry(facade) - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return responses.pop(0) model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -444,15 +416,15 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe assert result == "final result" # Verify tool usage stats are tracked correctly - assert model.usage_stats.tool_usage.total_tool_calls == 2 # 2 tool calls total - assert model.usage_stats.tool_usage.total_tool_call_turns == 2 # 2 turns with tool calls - assert model.usage_stats.tool_usage.total_generations == 1 # 1 generation - assert model.usage_stats.tool_usage.generations_with_tools == 1 # 1 generation with tools + assert model.usage_stats.tool_usage.total_tool_calls == 2 + assert model.usage_stats.tool_usage.total_tool_call_turns == 2 + assert model.usage_stats.tool_usage.total_generations == 1 + assert model.usage_stats.tool_usage.generations_with_tools == 1 def test_generate_with_tools_tracks_multiple_generations( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Tool usage is correctly tracked across multiple generations.""" @@ -461,35 +433,35 @@ def test_generate_with_tools_tracks_multiple_generations( model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) # Generation 1: 2 tool calls across 1 turn - tool_call_a = {"id": "call-a", "type": "function", "function": {"name": "lookup", "arguments": '{"q": "1"}'}} - tool_call_b = {"id": "call-b", "type": "function", "function": {"name": "lookup", "arguments": '{"q": "2"}'}} + tool_call_a = ToolCall(id="call-a", name="lookup", arguments_json='{"q": "1"}') + tool_call_b = ToolCall(id="call-b", name="lookup", arguments_json='{"q": "2"}') responses_gen1 = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call_a, tool_call_b])), - StubResponse(StubMessage(content="result 1")), + _make_response("", tool_calls=[tool_call_a, tool_call_b]), + _make_response("result 1"), ] - def _completion_gen1(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion_gen1(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return responses_gen1.pop(0) with patch.object(ModelFacade, "completion", new=_completion_gen1): model.generate(prompt="q1", parser=lambda x: x, tool_alias="tools") # Generation 2: 4 tool calls across 2 turns - tool_call_c = {"id": "call-c", "type": "function", "function": {"name": "search", "arguments": '{"q": "3"}'}} - tool_call_d = {"id": "call-d", "type": "function", "function": {"name": "search", "arguments": '{"q": "4"}'}} + tool_call_c = ToolCall(id="call-c", name="search", 
arguments_json='{"q": "3"}') + tool_call_d = ToolCall(id="call-d", name="search", arguments_json='{"q": "4"}') responses_gen2 = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call_a, tool_call_b])), - StubResponse(StubMessage(content="", tool_calls=[tool_call_c, tool_call_d])), - StubResponse(StubMessage(content="result 2")), + _make_response("", tool_calls=[tool_call_a, tool_call_b]), + _make_response("", tool_calls=[tool_call_c, tool_call_d]), + _make_response("result 2"), ] - def _completion_gen2(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion_gen2(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return responses_gen2.pop(0) with patch.object(ModelFacade, "completion", new=_completion_gen2): @@ -497,10 +469,10 @@ def _completion_gen2(self: Any, messages: list[ChatMessage], **kwargs: Any) -> S # Generation 3: No tool calls responses_gen3 = [ - StubResponse(StubMessage(content="result 3")), + _make_response("result 3"), ] - def _completion_gen3(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion_gen3(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return responses_gen3.pop(0) with patch.object(ModelFacade, "completion", new=_completion_gen3): @@ -515,37 +487,36 @@ def _completion_gen3(self: Any, messages: list[ChatMessage], **kwargs: Any) -> S def test_generate_tool_turn_limit_triggers_refusal( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """When max_tool_call_turns exceeded, refusal is used.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") - # Keep returning tool calls to exceed the limit responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Turn 1 - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Turn 2 (max) - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Turn 3 (exceeds, should refuse) - StubResponse(StubMessage(content="final answer after refusal")), + _make_response("", tool_calls=[tool_call]), # Turn 1 + _make_response("", tool_calls=[tool_call]), # Turn 2 (max) + _make_response("", tool_calls=[tool_call]), # Turn 3 (exceeds, should refuse) + _make_response("final answer after refusal"), ] process_calls = 0 refuse_calls = 0 - def custom_process_fn(completion_response: Any) -> list[ChatMessage]: + tc_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + + def custom_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: nonlocal process_calls process_calls += 1 - message = completion_response.choices[0].message return [ - ChatMessage.as_assistant(content="", tool_calls=message.tool_calls or []), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="tool-result", tool_call_id="call-1"), ] - def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: + def custom_refuse_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: nonlocal refuse_calls refuse_calls += 1 - message = completion_response.choices[0].message return [ - ChatMessage.as_assistant(content="", tool_calls=message.tool_calls or []), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="REFUSED: Budget exceeded", 
tool_call_id="call-1"), ] @@ -554,7 +525,7 @@ def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -562,8 +533,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -577,20 +548,21 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_turn_limit_model_responds_after_refusal( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Model provides final answer after refusal message.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") + tc_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Exceeds on first turn - StubResponse(StubMessage(content="I understand, here is the answer without tools")), + _make_response("", tool_calls=[tool_call]), # Exceeds on first turn + _make_response("I understand, here is the answer without tools"), ] - def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: + def custom_refuse_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: return [ - ChatMessage.as_assistant(content="", tool_calls=[tool_call]), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool( content="Tool call refused: You have reached the maximum number of tool-calling turns.", tool_call_id="call-1", @@ -606,7 +578,7 @@ def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -614,8 +586,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -629,20 +601,20 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_alias_not_in_registry( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Raises error when tool_alias not found in MCPRegistry.""" - class StubMCPRegistry: + class _StubMCPRegistry: def get_mcp(self, *, tool_alias: str) -> Any: raise ValueError(f"No tool config with alias {tool_alias!r} found!") model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, - mcp_registry=StubMCPRegistry(), + client=stub_model_client, + mcp_registry=_StubMCPRegistry(), ) with 
pytest.raises(MCPConfigurationError, match="not registered"): @@ -651,27 +623,27 @@ def get_mcp(self, *, tool_alias: str) -> Any: def test_generate_no_tool_alias_ignores_mcp( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """When tool_alias is None, no MCP operations occur.""" get_mcp_called = False - class StubMCPRegistry: + class _StubMCPRegistry: def get_mcp(self, *, tool_alias: str) -> Any: nonlocal get_mcp_called get_mcp_called = True raise RuntimeError("Should not be called") - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: assert "tools" not in kwargs # No tools should be passed - return StubResponse(StubMessage(content="response without tools")) + return _make_response("response without tools") model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, - mcp_registry=StubMCPRegistry(), + client=stub_model_client, + mcp_registry=_StubMCPRegistry(), ) with patch.object(ModelFacade, "completion", new=_completion): @@ -683,17 +655,17 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_calls_with_parser_corrections( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Tool calling works correctly with parser correction steps.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") parse_count = 0 responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Tool call - StubResponse(StubMessage(content="bad format")), # Parser will fail - StubResponse(StubMessage(content="correct format")), # Parser will succeed + _make_response("", tool_calls=[tool_call]), # Tool call + _make_response("bad format"), # Parser will fail + _make_response("correct format"), # Parser will succeed ] facade = StubMCPFacade() @@ -701,7 +673,7 @@ def test_generate_tool_calls_with_parser_corrections( response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -716,8 +688,8 @@ def _parser(text: str) -> str: model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -730,20 +702,18 @@ def _parser(text: str) -> str: def test_generate_tool_calls_with_conversation_restarts( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Tool calling works correctly with conversation restarts.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") messages_at_call: list[int] = [] - # First conversation: tool call + bad response - # After restart: tool call + good response responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), - 
StubResponse(StubMessage(content="still bad")), # Fails parser, triggers restart - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # After restart - StubResponse(StubMessage(content="good result")), + _make_response("", tool_calls=[tool_call]), + _make_response("still bad"), # Fails parser, triggers restart + _make_response("", tool_calls=[tool_call]), # After restart + _make_response("good result"), ] facade = StubMCPFacade() @@ -751,7 +721,7 @@ def test_generate_tool_calls_with_conversation_restarts( response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx messages_at_call.append(len(messages)) resp = responses[response_idx] @@ -765,8 +735,8 @@ def _parser(text: str) -> str: model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -777,7 +747,7 @@ def _parser(text: str) -> str: assert result == "good result" # After restart, message count should preserve tool call history (restart from checkpoint) - assert messages_at_call[2] == messages_at_call[1] # Both should be post-tool-call message count + assert messages_at_call[2] == messages_at_call[1] # ============================================================================= @@ -787,15 +757,15 @@ def _parser(text: str) -> str: def test_generate_trace_includes_tool_calls( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Returned trace includes tool call messages.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": '{"q": "test"}'}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json='{"q": "test"}') responses = [ - StubResponse(StubMessage(content="Let me look that up", tool_calls=[tool_call])), - StubResponse(StubMessage(content="Here is the answer")), + _make_response("Let me look that up", tool_calls=[tool_call]), + _make_response("Here is the answer"), ] facade = StubMCPFacade() @@ -803,7 +773,7 @@ def test_generate_trace_includes_tool_calls( response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -811,8 +781,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -827,20 +797,21 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_trace_includes_tool_responses( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Returned trace includes tool response messages.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") + tc_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} responses = [ - 
StubResponse(StubMessage(content="", tool_calls=[tool_call])), - StubResponse(StubMessage(content="final")), + _make_response("", tool_calls=[tool_call]), + _make_response("final"), ] - def custom_process_fn(completion_response: Any) -> list[ChatMessage]: + def custom_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: return [ - ChatMessage.as_assistant(content="", tool_calls=[tool_call]), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="THE TOOL RESPONSE CONTENT", tool_call_id="call-1"), ] @@ -849,7 +820,7 @@ def custom_process_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -857,8 +828,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -873,20 +844,21 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_trace_includes_refusal_messages( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Returned trace includes refusal messages when budget exhausted.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") + tc_dict = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} responses = [ - StubResponse(StubMessage(content="", tool_calls=[tool_call])), # Will be refused (max_turns=0) - StubResponse(StubMessage(content="answer without tools")), + _make_response("", tool_calls=[tool_call]), # Will be refused (max_turns=0) + _make_response("answer without tools"), ] - def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: + def custom_refuse_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: return [ - ChatMessage.as_assistant(content="", tool_calls=[tool_call]), + ChatMessage.as_assistant(content="", tool_calls=[tc_dict]), ChatMessage.as_tool(content="BUDGET_EXCEEDED_REFUSAL", tool_call_id="call-1"), ] @@ -899,7 +871,7 @@ def custom_refuse_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -907,8 +879,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -922,24 +894,22 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_trace_preserves_reasoning_content( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Trace messages preserve reasoning_content 
field.""" - response = StubResponse( - StubMessage( - content="The answer is 42", - reasoning_content="Let me think about this carefully...", - ) + response = _make_response( + "The answer is 42", + reasoning_content="Let me think about this carefully...", ) - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: return response model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, ) with patch.object(ModelFacade, "completion", new=_completion): @@ -958,15 +928,15 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_execution_error( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Handles MCP tool execution errors appropriately.""" - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "{}"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="{}") - responses = [StubResponse(StubMessage(content="", tool_calls=[tool_call]))] + responses = [_make_response("", tool_calls=[tool_call])] - def error_process_fn(completion_response: Any) -> list[ChatMessage]: + def error_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: raise MCPToolError("Tool execution failed: Connection refused") facade = StubMCPFacade(process_fn=error_process_fn) @@ -974,7 +944,7 @@ def error_process_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -982,8 +952,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -994,16 +964,15 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe def test_generate_tool_invalid_arguments( stub_model_configs: Any, - stub_secrets_resolver: Any, + stub_model_client: MagicMock, stub_model_provider_registry: Any, ) -> None: """Handles invalid tool arguments from LLM.""" - # Tool call with invalid JSON arguments - tool_call = {"id": "call-1", "type": "function", "function": {"name": "lookup", "arguments": "not valid json"}} + tool_call = ToolCall(id="call-1", name="lookup", arguments_json="not valid json") - responses = [StubResponse(StubMessage(content="", tool_calls=[tool_call]))] + responses = [_make_response("", tool_calls=[tool_call])] - def error_process_fn(completion_response: Any) -> list[ChatMessage]: + def error_process_fn(completion_response: ChatCompletionResponse) -> list[ChatMessage]: raise MCPToolError("Invalid tool arguments for 'lookup': not valid json") facade = StubMCPFacade(process_fn=error_process_fn) @@ -1011,7 +980,7 @@ def error_process_fn(completion_response: Any) -> list[ChatMessage]: response_idx = 0 - def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubResponse: + def _completion(self: Any, messages: list[ChatMessage], 
**kwargs: Any) -> ChatCompletionResponse: nonlocal response_idx resp = responses[response_idx] response_idx += 1 @@ -1019,8 +988,8 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe model = ModelFacade( model_config=stub_model_configs[0], - secret_resolver=stub_secrets_resolver, model_provider_registry=stub_model_provider_registry, + client=stub_model_client, mcp_registry=registry, ) @@ -1034,252 +1003,107 @@ def _completion(self: Any, messages: list[ChatMessage], **kwargs: Any) -> StubRe # ============================================================================= -@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) def test_generate_image_diffusion_tracks_image_usage( - mock_image_generation: Any, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image tracks image usage for diffusion models.""" - # Mock response with 3 images - mock_response = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image1_base64"), - lazy.litellm.types.utils.ImageObject(b64_json="image2_base64"), - lazy.litellm.types.utils.ImageObject(b64_json="image3_base64"), + stub_model_client.generate_image.return_value = ImageGenerationResponse( + images=[ + ImagePayload(b64_data="image1_base64"), + ImagePayload(b64_data="image2_base64"), + ImagePayload(b64_data="image3_base64"), ] ) - mock_image_generation.return_value = mock_response - # Verify initial state assert stub_model_facade.usage_stats.image_usage.total_images == 0 - # Generate images with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): images = stub_model_facade.generate_image(prompt="test prompt", n=3) - # Verify results assert len(images) == 3 assert images == ["image1_base64", "image2_base64", "image3_base64"] - - # Verify image usage was tracked assert stub_model_facade.usage_stats.image_usage.total_images == 3 assert stub_model_facade.usage_stats.image_usage.has_usage is True -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) def test_generate_image_chat_completion_tracks_image_usage( - mock_completion: Any, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image tracks image usage for chat completion models.""" - # Mock response with images attribute (Message requires type and index per ImageURLListItem) - mock_message = lazy.litellm.types.utils.Message( - role="assistant", - content="", + stub_model_client.generate_image.return_value = ImageGenerationResponse( images=[ - lazy.litellm.types.utils.ImageURLListItem( - type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0 - ), - lazy.litellm.types.utils.ImageURLListItem( - type="image_url", image_url={"url": "data:image/png;base64,image2"}, index=1 - ), - ], - ) - mock_response = lazy.litellm.types.utils.ModelResponse( - choices=[lazy.litellm.types.utils.Choices(message=mock_message)] + ImagePayload(b64_data="image1"), + ImagePayload(b64_data="image2"), + ] ) - mock_completion.return_value = mock_response - # Verify initial state assert stub_model_facade.usage_stats.image_usage.total_images == 0 - # Generate images with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): images = stub_model_facade.generate_image(prompt="test prompt") - # Verify results assert len(images) == 2 assert images == ["image1", "image2"] - - # Verify image usage was tracked assert 
stub_model_facade.usage_stats.image_usage.total_images == 2 assert stub_model_facade.usage_stats.image_usage.has_usage is True -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def test_generate_image_chat_completion_with_dict_format( - mock_completion: Any, - stub_model_facade: ModelFacade, -) -> None: - """Test that generate_image handles images as dicts with image_url string.""" - # Create mock message with images as dict with string image_url - mock_message = MagicMock() - mock_message.role = "assistant" - mock_message.content = "" - mock_message.images = [ - {"image_url": "data:image/png;base64,image1"}, - {"image_url": "data:image/jpeg;base64,image2"}, - ] - - mock_choice = MagicMock() - mock_choice.message = mock_message - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - - mock_completion.return_value = mock_response - - # Generate images - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): - images = stub_model_facade.generate_image(prompt="test prompt") - - # Verify results - assert len(images) == 2 - assert images == ["image1", "image2"] - - -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def test_generate_image_chat_completion_with_plain_strings( - mock_completion: Any, - stub_model_facade: ModelFacade, -) -> None: - """Test that generate_image handles images as plain strings.""" - # Create mock message with images as plain strings - mock_message = MagicMock() - mock_message.role = "assistant" - mock_message.content = "" - mock_message.images = [ - "data:image/png;base64,image1", - "image2", # Plain base64 without data URI prefix - ] - - mock_choice = MagicMock() - mock_choice.message = mock_message - - mock_response = MagicMock() - mock_response.choices = [mock_choice] - - mock_completion.return_value = mock_response - - # Generate images - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): - images = stub_model_facade.generate_image(prompt="test prompt") - - # Verify results - assert len(images) == 2 - assert images == ["image1", "image2"] - - -@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) def test_generate_image_skip_usage_tracking( - mock_image_generation: Any, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image respects skip_usage_tracking flag.""" - mock_response = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image1_base64"), - lazy.litellm.types.utils.ImageObject(b64_json="image2_base64"), + stub_model_client.generate_image.return_value = ImageGenerationResponse( + images=[ + ImagePayload(b64_data="image1_base64"), + ImagePayload(b64_data="image2_base64"), ] ) - mock_image_generation.return_value = mock_response - # Verify initial state assert stub_model_facade.usage_stats.image_usage.total_images == 0 - # Generate images with skip_usage_tracking=True with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): images = stub_model_facade.generate_image(prompt="test prompt", skip_usage_tracking=True) - # Verify results assert len(images) == 2 - - # Verify image usage was NOT tracked assert stub_model_facade.usage_stats.image_usage.total_images == 0 assert stub_model_facade.usage_stats.image_usage.has_usage is False -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def 
test_generate_image_chat_completion_no_choices( - mock_completion: Any, - stub_model_facade: ModelFacade, -) -> None: - """Test that generate_image raises ImageGenerationError when response has no choices.""" - mock_response = lazy.litellm.types.utils.ModelResponse(choices=[]) - mock_completion.return_value = mock_response - - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): - with pytest.raises(ImageGenerationError, match="Image generation response missing choices"): - stub_model_facade.generate_image(prompt="test prompt") - - -@patch("data_designer.engine.models.facade.ModelFacade.completion", autospec=True) -def test_generate_image_chat_completion_no_image_data( - mock_completion: Any, +def test_generate_image_no_image_data( stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image raises ImageGenerationError when no image data in response.""" - mock_message = lazy.litellm.types.utils.Message(role="assistant", content="just text, no image") - mock_response = lazy.litellm.types.utils.ModelResponse( - choices=[lazy.litellm.types.utils.Choices(message=mock_message)] - ) - mock_completion.return_value = mock_response + stub_model_client.generate_image.return_value = ImageGenerationResponse(images=[]) with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): - with pytest.raises(ImageGenerationError, match="No image data found in image generation response"): + with pytest.raises(ImageGenerationError, match="No image data found"): stub_model_facade.generate_image(prompt="test prompt") -@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) -def test_generate_image_diffusion_no_data( - mock_image_generation: Any, - stub_model_facade: ModelFacade, -) -> None: - """Test that generate_image raises ImageGenerationError when diffusion API returns no data.""" - mock_response = lazy.litellm.types.utils.ImageResponse(data=[]) - mock_image_generation.return_value = mock_response - - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): - with pytest.raises(ImageGenerationError, match="Image generation returned no data"): - stub_model_facade.generate_image(prompt="test prompt") - - -@patch("data_designer.engine.models.facade.CustomRouter.image_generation", autospec=True) def test_generate_image_accumulates_usage( - mock_image_generation: Any, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test that generate_image accumulates image usage across multiple calls.""" - # First call - 2 images - mock_response1 = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image1"), - lazy.litellm.types.utils.ImageObject(b64_json="image2"), - ] + response1 = ImageGenerationResponse(images=[ImagePayload(b64_data="image1"), ImagePayload(b64_data="image2")]) + response2 = ImageGenerationResponse( + images=[ImagePayload(b64_data="image3"), ImagePayload(b64_data="image4"), ImagePayload(b64_data="image5")] ) - # Second call - 3 images - mock_response2 = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image3"), - lazy.litellm.types.utils.ImageObject(b64_json="image4"), - lazy.litellm.types.utils.ImageObject(b64_json="image5"), - ] - ) - mock_image_generation.side_effect = [mock_response1, mock_response2] + stub_model_client.generate_image.side_effect = [response1, response2] - # Verify initial state 
assert stub_model_facade.usage_stats.image_usage.total_images == 0 - # First generation with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): images1 = stub_model_facade.generate_image(prompt="test1") assert len(images1) == 2 assert stub_model_facade.usage_stats.image_usage.total_images == 2 - # Second generation images2 = stub_model_facade.generate_image(prompt="test2") assert len(images2) == 3 - # Usage should accumulate assert stub_model_facade.usage_stats.image_usage.total_images == 5 @@ -1295,52 +1119,43 @@ def test_generate_image_accumulates_usage( True, ], ) -@patch.object(CustomRouter, "acompletion", new_callable=AsyncMock) @pytest.mark.asyncio async def test_acompletion_success( - mock_router_acompletion: AsyncMock, stub_completion_messages: list[ChatMessage], stub_model_configs: Any, stub_model_facade: ModelFacade, - stub_expected_completion_response: ModelResponse, + stub_model_client: MagicMock, skip_usage_tracking: bool, ) -> None: - mock_router_acompletion.return_value = stub_expected_completion_response + expected_response = _make_response("Test response") + stub_model_client.acompletion = AsyncMock(return_value=expected_response) result = await stub_model_facade.acompletion(stub_completion_messages, skip_usage_tracking=skip_usage_tracking) - expected_messages = [message.to_dict() for message in stub_completion_messages] - assert result == stub_expected_completion_response - assert mock_router_acompletion.call_count == 1 - assert mock_router_acompletion.call_args[1] == { - "model": "stub-model-text", - "messages": expected_messages, - **stub_model_configs[0].inference_parameters.generate_kwargs, - } + assert result == expected_response + assert stub_model_client.acompletion.call_count == 1 -@patch.object(CustomRouter, "acompletion", new_callable=AsyncMock) @pytest.mark.asyncio async def test_acompletion_with_exception( - mock_router_acompletion: AsyncMock, stub_completion_messages: list[ChatMessage], stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: - mock_router_acompletion.side_effect = Exception("Router error") + stub_model_client.acompletion = AsyncMock(side_effect=Exception("Router error")) with pytest.raises(Exception, match="Router error"): await stub_model_facade.acompletion(stub_completion_messages) -@patch.object(CustomRouter, "aembedding", new_callable=AsyncMock) @pytest.mark.asyncio async def test_agenerate_text_embeddings_success( - mock_router_aembedding: AsyncMock, stub_model_facade: ModelFacade, - stub_expected_embedding_response: EmbeddingResponse, + stub_model_client: MagicMock, ) -> None: - mock_router_aembedding.return_value = stub_expected_embedding_response + expected_vectors = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + stub_model_client.aembeddings = AsyncMock(return_value=EmbeddingResponse(vectors=expected_vectors)) input_texts = ["test1", "test2"] result = await stub_model_facade.agenerate_text_embeddings(input_texts) - assert result == [data["embedding"] for data in stub_expected_embedding_response.data] + assert result == expected_vectors @pytest.mark.parametrize( @@ -1363,7 +1178,7 @@ async def test_agenerate_correction_retries( max_conversation_restarts: int, total_calls: int, ) -> None: - bad_response = mock_oai_response_object("bad response") + bad_response = _make_response("bad response") mock_acompletion.return_value = bad_response def _failing_parser(response: str) -> str: @@ -1396,13 +1211,12 @@ async def test_agenerate_success( mock_acompletion: AsyncMock, stub_model_facade: 
ModelFacade, ) -> None: - good_response = mock_oai_response_object("parsed output") + good_response = _make_response("parsed output") mock_acompletion.return_value = good_response result, trace = await stub_model_facade.agenerate(prompt="test", parser=lambda x: x) assert result == "parsed output" assert mock_acompletion.call_count == 1 - # Trace should contain at least the user prompt and the assistant response assert any(msg.role == "user" for msg in trace) assert any(msg.role == "assistant" and msg.content == "parsed output" for msg in trace) @@ -1412,105 +1226,52 @@ async def test_agenerate_success( # ============================================================================= -@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock) @pytest.mark.asyncio async def test_agenerate_image_diffusion_success( - mock_aimage_generation: AsyncMock, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test async image generation via diffusion API.""" - mock_response = lazy.litellm.types.utils.ImageResponse( - data=[ - lazy.litellm.types.utils.ImageObject(b64_json="image1_base64"), - lazy.litellm.types.utils.ImageObject(b64_json="image2_base64"), - ] + stub_model_client.agenerate_image = AsyncMock( + return_value=ImageGenerationResponse( + images=[ImagePayload(b64_data="image1_base64"), ImagePayload(b64_data="image2_base64")] + ) ) - mock_aimage_generation.return_value = mock_response with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): images = await stub_model_facade.agenerate_image(prompt="test prompt") assert len(images) == 2 assert images == ["image1_base64", "image2_base64"] - assert mock_aimage_generation.call_count == 1 - # Verify image usage was tracked assert stub_model_facade.usage_stats.image_usage.total_images == 2 -@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) @pytest.mark.asyncio async def test_agenerate_image_chat_completion_success( - mock_acompletion: AsyncMock, stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: """Test async image generation via chat completion API.""" - mock_message = lazy.litellm.types.utils.Message( - role="assistant", - content="", - images=[ - lazy.litellm.types.utils.ImageURLListItem( - type="image_url", image_url={"url": "data:image/png;base64,image1"}, index=0 - ), - ], + stub_model_client.agenerate_image = AsyncMock( + return_value=ImageGenerationResponse(images=[ImagePayload(b64_data="image1")]) ) - mock_response = lazy.litellm.types.utils.ModelResponse( - choices=[lazy.litellm.types.utils.Choices(message=mock_message)] - ) - mock_acompletion.return_value = mock_response with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): images = await stub_model_facade.agenerate_image(prompt="test prompt") assert len(images) == 1 assert images == ["image1"] - assert mock_acompletion.call_count == 1 assert stub_model_facade.usage_stats.image_usage.total_images == 1 -@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock) @pytest.mark.asyncio -async def test_agenerate_image_diffusion_no_data( - mock_aimage_generation: AsyncMock, +async def test_agenerate_image_no_data( stub_model_facade: ModelFacade, + stub_model_client: MagicMock, ) -> None: - """Test async image generation raises error when diffusion API returns no data.""" - mock_response = lazy.litellm.types.utils.ImageResponse(data=[]) - mock_aimage_generation.return_value = 
mock_response + """Test async image generation raises error when no data.""" + stub_model_client.agenerate_image = AsyncMock(return_value=ImageGenerationResponse(images=[])) with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): - with pytest.raises(ImageGenerationError, match="Image generation returned no data"): + with pytest.raises(ImageGenerationError, match="No image data found"): await stub_model_facade.agenerate_image(prompt="test prompt") - - -@patch.object(ModelFacade, "acompletion", new_callable=AsyncMock) -@pytest.mark.asyncio -async def test_agenerate_image_chat_completion_no_choices( - mock_acompletion: AsyncMock, - stub_model_facade: ModelFacade, -) -> None: - """Test async image generation raises error when response has no choices.""" - mock_response = lazy.litellm.types.utils.ModelResponse(choices=[]) - mock_acompletion.return_value = mock_response - - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=False): - with pytest.raises(ImageGenerationError, match="Image generation response missing choices"): - await stub_model_facade.agenerate_image(prompt="test prompt") - - -@patch("data_designer.engine.models.facade.CustomRouter.aimage_generation", new_callable=AsyncMock) -@pytest.mark.asyncio -async def test_agenerate_image_skip_usage_tracking( - mock_aimage_generation: AsyncMock, - stub_model_facade: ModelFacade, -) -> None: - """Test that async image generation respects skip_usage_tracking flag.""" - mock_response = lazy.litellm.types.utils.ImageResponse( - data=[lazy.litellm.types.utils.ImageObject(b64_json="image1_base64")] - ) - mock_aimage_generation.return_value = mock_response - - with patch("data_designer.engine.models.facade.is_image_diffusion_model", return_value=True): - images = await stub_model_facade.agenerate_image(prompt="test prompt", skip_usage_tracking=True) - - assert len(images) == 1 - assert stub_model_facade.usage_stats.image_usage.total_images == 0 diff --git a/plans/343/model-facade-overhaul-pr-2-architecture-notes.md b/plans/343/model-facade-overhaul-pr-2-architecture-notes.md new file mode 100644 index 000000000..5a07a9549 --- /dev/null +++ b/plans/343/model-facade-overhaul-pr-2-architecture-notes.md @@ -0,0 +1,140 @@ +--- +date: 2026-03-04 +authors: + - nmulepati +--- + +# Model Facade Overhaul PR-2 Architecture Notes + +This document captures the architecture intent for PR-2 from +`plans/343/model-facade-overhaul-plan-step-1.md`. + +## Goal + +Switch `ModelFacade` from direct LiteLLM router usage to the `ModelClient` protocol +introduced in PR-1. After this PR, `ModelFacade` consumes only canonical types +(`ChatCompletionResponse`, `EmbeddingResponse`, `ImageGenerationResponse`) and has +no direct import or runtime dependency on LiteLLM response shapes. + +## What Changes + +### 1. ModelFacade internals rewired to ModelClient + +`ModelFacade.__init__` currently constructs a `CustomRouter` and calls it directly: + +```python +self._router = CustomRouter([self._litellm_deployment], ...) +# ... +response = self._router.completion(model=..., messages=..., **kwargs) +``` + +After PR-2, it receives a `ModelClient` (selected by factory) and builds canonical requests: + +```python +self._client: ModelClient # injected via factory +# ... 
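+# NOTE: `consolidated` stands for the merged generate kwargs (presumably built
+# by consolidate_kwargs, whose extra_body / extra_headers merge semantics are
+# unchanged by this PR; see "What Does NOT Change" below).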
+request = ChatCompletionRequest(model=..., messages=..., **consolidated)
+response: ChatCompletionResponse = self._client.completion(request)
+```
+
+The same pattern applies to embeddings (`EmbeddingRequest` → `EmbeddingResponse`) and
+image generation (`ImageGenerationRequest` → `ImageGenerationResponse`).
+
+### 2. Client factory
+
+New file: `clients/factory.py`
+
+It is responsible for selecting the right `ModelClient` adapter for a given
+`ModelConfig` and provider context. For PR-2, the only adapter is
+`LiteLLMBridgeClient`. The factory encapsulates the router construction and
+deployment config that currently live in `ModelFacade._get_litellm_deployment`.
+A rough sketch of this factory appears after the What Does NOT Change list below.
+
+`models/factory.py` (`create_model_registry`) is updated to use the client factory
+when constructing each `ModelFacade`.
+
+### 3. MCP compatibility update
+
+`MCPFacade` methods (`has_tool_calls`, `tool_call_count`, `process_completion_response`,
+`refuse_completion_response`) currently accept `Any` and traverse
+`completion_response.choices[0].message` with `getattr` for LiteLLM shapes.
+
+PR-2 updates these to accept `ChatCompletionResponse` and read from canonical fields:
+
+- `response.message.tool_calls` → `list[ToolCall]` (id, name, arguments_json)
+- `response.message.content` → `str | None`
+- `response.message.reasoning_content` → `str | None`
+
+`_extract_tool_calls` and `_normalize_tool_call` simplify significantly because
+canonical `ToolCall` is already normalized (no nested `function` key, no dict vs
+object polymorphism).
+
+### 4. Usage tracking consolidation
+
+The three existing methods:
+
+- `_track_token_usage_from_completion`
+- `_track_token_usage_from_embedding`
+- `_track_token_usage_from_image_diffusion`
+
+All three read from provider-specific usage shapes (`litellm.types.utils.*`). PR-2
+replaces them with a single helper that reads from canonical `Usage`:
+
+```python
+def _track_usage(self, usage: Usage | None, *, is_request_successful: bool) -> None
+```
+
+### 5. Image extraction moves into adapter
+
+`ModelFacade` currently does image extraction from raw LiteLLM responses
+(`_try_extract_base64`, `_generate_image_chat_completion`, `_generate_image_diffusion`).
+
+After PR-2, the adapter returns `ImageGenerationResponse.images: list[ImagePayload]`
+with `b64_data` already resolved. `ModelFacade.generate_image` / `agenerate_image`
+simply read `response.images` and extract the `b64_data` values; no format
+detection, URL downloading, or data URI parsing remains at the facade level.
+
+### 6. LiteLLM type removal from facade
+
+After PR-2, `facade.py` no longer imports:
+
+- `litellm` (the module, currently used for type hints)
+- `CustomRouter`, `LiteLLMRouterDefaultKwargs`
+- `litellm.types.utils.ModelResponse`, `EmbeddingResponse`, `ImageResponse`, `ImageUsage`
+
+These remain internal to `LiteLLMBridgeClient` and `models/factory.py`.
+
+### 7. Adapter lifecycle wiring
+
+`ModelClient.close()` / `aclose()` are wired through `ModelRegistry` so adapter
+resources (HTTP clients, connection pools) are torn down deterministically when
+generation is complete. A sketch of this wiring appears in the appendix at the
+end of this document.
+
+- `ModelRegistry` gains `close()` / `aclose()` that iterate owned facades.
+- `ModelFacade` gains `close()` / `aclose()` that delegate to `self._client`.
+- `ResourceProvider` (or equivalent teardown hook) calls `ModelRegistry.close()`.
+
+## What Does NOT Change
+
+1. `ModelFacade` public method signatures — callers see the same API.
+2. MCP tool-loop behavior — tool turns, refusal, parallel execution all preserved.
+3. Usage accounting semantics — token, request, image, and tool usage remain identical.
+4. Error boundaries — `@catch_llm_exceptions` / `@acatch_llm_exceptions` decorators
+   and `DataDesignerError` subclass hierarchy remain stable.
+5. `consolidate_kwargs` merge semantics for `extra_body` / `extra_headers`.
+6. `generate` / `agenerate` parser correction/restart loop logic.
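+
+To make the section 2 factory shape concrete, a minimal sketch follows. Only
+`clients/factory.py`, `ModelClient`, and `LiteLLMBridgeClient` come from this
+plan; the function name, signature, and import paths are illustrative
+assumptions, not the final API:
+
+```python
+# clients/factory.py (sketch only). Everything below except ModelClient and
+# LiteLLMBridgeClient is an assumption for illustration.
+from __future__ import annotations
+
+from typing import Any
+
+from data_designer.engine.models.clients.litellm_bridge import LiteLLMBridgeClient  # path assumed
+from data_designer.engine.models.clients.types import ModelClient  # protocol location assumed
+
+
+def create_model_client(model_config: Any, provider_context: Any) -> ModelClient:
+    """Select a ModelClient adapter for the given model config.
+
+    Encapsulates the router construction and deployment config that
+    previously lived in ModelFacade._get_litellm_deployment.
+    """
+    # PR-2 ships a single adapter; PR-3 adds an OpenAI-compatible native
+    # adapter branch here, keyed off the provider config.
+    return LiteLLMBridgeClient(model_config, provider_context)
+```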
+
+## Files Touched
+
+| File | Change |
+|---|---|
+| `models/facade.py` | Rewire to `ModelClient`, canonical types, consolidated usage tracking |
+| `models/factory.py` | Use client factory to inject `ModelClient` into `ModelFacade` |
+| `models/registry.py` | Add `close` / `aclose` lifecycle methods |
+| `clients/factory.py` | New — adapter selection by provider config |
+| `mcp/facade.py` | Accept `ChatCompletionResponse` instead of raw LiteLLM response |
+
+## Planned Follow-On
+
+PR-3 introduces the OpenAI-compatible native adapter with shared retry/throttle
+infrastructure. At that point, the client factory gains a second adapter option
+alongside the LiteLLM bridge.
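+
+## Appendix: Lifecycle Wiring Sketch
+
+A minimal sketch of the section 7 teardown path. The `close` / `aclose` method
+names and the delegation chain (registry → facade → client) come from this plan;
+the attribute names (`_facades`, `_client`) and abridged class skeletons are
+assumptions:
+
+```python
+# Sketch only: attribute names and class structure are illustrative.
+from __future__ import annotations
+
+
+class ModelFacade:  # abridged
+    def close(self) -> None:
+        # Tear down adapter resources (HTTP clients, connection pools).
+        self._client.close()
+
+    async def aclose(self) -> None:
+        await self._client.aclose()
+
+
+class ModelRegistry:  # abridged
+    def close(self) -> None:
+        # Invoked by ResourceProvider (or an equivalent teardown hook)
+        # once generation is complete.
+        for facade in self._facades.values():
+            facade.close()
+
+    async def aclose(self) -> None:
+        for facade in self._facades.values():
+            await facade.aclose()
+```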