From 41be850622871db20879b3c6a4fbfef13d11eeab Mon Sep 17 00:00:00 2001 From: Wenjing Yu Date: Mon, 14 Jul 2025 16:04:39 -0700 Subject: [PATCH 1/2] add tool support for openai and anthropic --- app/services/providers/anthropic_adapter.py | 152 +++++++++++- app/services/providers/openai_adapter.py | 131 ++++++++++ tests/unit_tests/test_openai_provider.py | 257 ++++++++++++++++++++ 3 files changed, 533 insertions(+), 7 deletions(-) diff --git a/app/services/providers/anthropic_adapter.py b/app/services/providers/anthropic_adapter.py index 4533d31..2514d64 100644 --- a/app/services/providers/anthropic_adapter.py +++ b/app/services/providers/anthropic_adapter.py @@ -9,6 +9,7 @@ from app.core.logger import get_logger from app.exceptions.exceptions import ProviderAPIException, InvalidCompletionRequestException +from app.exceptions.exceptions import BaseInvalidRequestException from .base import ProviderAdapter @@ -33,6 +34,78 @@ def __init__( def provider_name(self) -> str: return self._provider_name + def validate_tools(self, tools: list[dict[str, Any]]) -> None: + """Validate tools structure for Anthropic API compatibility""" + if not tools: + return + + for i, tool in enumerate(tools): + if not isinstance(tool, dict): + error_msg = f"Tool at index {i} must be a dictionary" + logger.error(f"Anthropic API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + # Anthropic uses "function" type for tools + if "function" not in tool: + error_msg = f"Tool at index {i} must have a 'function' object" + logger.error(f"Anthropic API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + function = tool["function"] + if not isinstance(function, dict): + error_msg = f"Function at index {i} must be a dictionary" + logger.error(f"Anthropic API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + function_name = function.get("name") + if not function_name or not isinstance(function_name, str): + error_msg = f"Function at index {i} must have a valid 'name' string" + logger.error(f"Anthropic API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + def validate_messages(self, messages: list[dict[str, Any]]) -> None: + """Validate message structure for Anthropic API compatibility""" + if not messages: + return + + for i, message in enumerate(messages): + role = message.get("role") + + # Check for tool messages that don't have proper preceding tool_calls + if role == "tool": + # Find the preceding assistant message with tool_calls + has_preceding_tool_calls = False + for j in range(i - 1, -1, -1): + prev_message = messages[j] + if prev_message.get("role") == "assistant": + if "tool_calls" in prev_message: + has_preceding_tool_calls = True + break + elif "content" in prev_message: + # If assistant message has content but no tool_calls, + # it's not a valid preceding message for tool role + break + + if not has_preceding_tool_calls: + error_msg = f"Message at index {i} with role 'tool' must be a response to a preceding message with 'tool_calls'" + logger.error(f"Anthropic API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + @staticmethod def convert_openai_image_content_to_anthropic( msg: dict[str, Any], @@ -126,6 +199,14 @@ async def process_completion( api_key: str, ) -> Any: """Process a completion request using Anthropic API""" + # Validate tools if present + if "tools" in payload: + self.validate_tools(payload["tools"]) + + # Validate messages for tool role if present + if "messages" in payload: + self.validate_messages(payload["messages"]) + headers = { "x-api-key": api_key, "Content-Type": "application/json", @@ -161,12 +242,19 @@ async def process_completion( anthropic_messages.append({"role": "user", "content": content}) elif role == "assistant": anthropic_messages.append({"role": "assistant", "content": content}) + elif role == "tool": + # Anthropic uses "tool" role for tool responses + anthropic_messages.append({"role": "tool", "content": content}) # Add system message if present if system_message: anthropic_payload["system"] = system_message anthropic_payload["messages"] = anthropic_messages + + # Add tools if present + if "tools" in payload: + anthropic_payload["tools"] = payload["tools"] else: # Handle regular completion (legacy format) anthropic_payload["prompt"] = f"Human: {payload['prompt']}\n\nAssistant: " @@ -248,7 +336,9 @@ async def stream_response() -> AsyncGenerator[bytes, None]: ) elif event_type == "content_block_delta": - delta_content = data.get("delta", {}).get("text", "") + delta = data.get("delta", {}) + delta_content = delta.get("text", "") + if delta_content: openai_chunk = { "id": request_id, @@ -263,6 +353,31 @@ async def stream_response() -> AsyncGenerator[bytes, None]: } ], } + elif "tool_use" in delta: + # Handle tool_use delta + tool_use = delta["tool_use"] + openai_chunk = { + "id": request_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "index": 0, + "delta": { + "tool_calls": [{ + "id": tool_use.get("id", "call_1"), + "type": "function", + "function": { + "name": tool_use.get("name", ""), + "arguments": tool_use.get("input", "{}") + } + }] + }, + "finish_reason": None, + } + ], + } # Capture Output Tokens & Finish Reason from message_delta elif event_type == "message_delta": @@ -365,14 +480,40 @@ async def _process_regular_response( # Messages API response content = anthropic_response.get("content", []) text_content = "" + tool_calls = [] - # Extract text from content blocks + # Extract text and tool calls from content blocks for block in content: if block.get("type") == "text": text_content += block.get("text", "") + elif block.get("type") == "tool_use": + # Convert Anthropic tool_use to OpenAI tool_calls format + tool_use = block.get("tool_use", {}) + tool_calls.append({ + "id": tool_use.get("id", f"call_{len(tool_calls)}"), + "type": "function", + "function": { + "name": tool_use.get("name", ""), + "arguments": tool_use.get("input", "{}") + } + }) input_tokens = anthropic_response.get("usage", {}).get("input_tokens", 0) output_tokens = anthropic_response.get("usage", {}).get("output_tokens", 0) + + # Determine finish reason + finish_reason = "stop" + if tool_calls: + finish_reason = "tool_calls" + + message = { + "role": "assistant", + "content": text_content if text_content else None, + } + + if tool_calls: + message["tool_calls"] = tool_calls + return { "id": completion_id, "object": "chat.completion", @@ -381,11 +522,8 @@ async def _process_regular_response( "choices": [ { "index": 0, - "message": { - "role": "assistant", - "content": text_content, - }, - "finish_reason": "stop", + "message": message, + "finish_reason": finish_reason, } ], "usage": { diff --git a/app/services/providers/openai_adapter.py b/app/services/providers/openai_adapter.py index 933778b..47caf3e 100644 --- a/app/services/providers/openai_adapter.py +++ b/app/services/providers/openai_adapter.py @@ -28,6 +28,125 @@ def __init__( def provider_name(self) -> str: return self._provider_name + def validate_messages(self, messages: list[dict[str, Any]]) -> None: + """Validate message structure for OpenAI API compatibility""" + if not messages: + return + + for i, message in enumerate(messages): + role = message.get("role") + + # Check for tool messages that don't have proper preceding tool_calls + if role == "tool": + # Find the preceding assistant message with tool_calls + has_preceding_tool_calls = False + for j in range(i - 1, -1, -1): + prev_message = messages[j] + if prev_message.get("role") == "assistant": + if "tool_calls" in prev_message: + has_preceding_tool_calls = True + break + elif "content" in prev_message: + # If assistant message has content but no tool_calls, + # it's not a valid preceding message for tool role + break + + if not has_preceding_tool_calls: + error_msg = f"Message at index {i} with role 'tool' must be a response to a preceding message with 'tool_calls'" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + def validate_tools(self, tools: list[dict[str, Any]]) -> None: + """Validate tools structure for OpenAI API compatibility""" + if not tools: + return + + for i, tool in enumerate(tools): + if not isinstance(tool, dict): + error_msg = f"Tool at index {i} must be a dictionary" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + tool_type = tool.get("type") + if tool_type != "function": + error_msg = f"Tool at index {i} has unsupported type '{tool_type}'. Only 'function' type is supported" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + function = tool.get("function") + if not function or not isinstance(function, dict): + error_msg = f"Tool at index {i} must have a 'function' object" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + function_name = function.get("name") + if not function_name or not isinstance(function_name, str): + error_msg = f"Function at index {i} must have a valid 'name' string" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + def validate_tool_choice(self, tool_choice: Any) -> None: + """Validate tool_choice parameter for OpenAI API compatibility""" + if tool_choice is None: + return + + if isinstance(tool_choice, str): + valid_choices = ["none", "auto"] + if tool_choice not in valid_choices: + error_msg = f"tool_choice must be one of {valid_choices}, got '{tool_choice}'" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + elif isinstance(tool_choice, dict): + if "type" not in tool_choice: + error_msg = "tool_choice object must have a 'type' field" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + + tool_choice_type = tool_choice.get("type") + if tool_choice_type == "function": + if "function" not in tool_choice: + error_msg = "tool_choice with type 'function' must have a 'function' object" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + elif tool_choice_type not in ["none", "auto"]: + error_msg = f"tool_choice type must be one of ['none', 'auto', 'function'], got '{tool_choice_type}'" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + else: + error_msg = f"tool_choice must be a string or object, got {type(tool_choice).__name__}" + logger.error(f"OpenAI API validation error: {error_msg}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError(error_msg) + ) + def get_model_id(self, payload: dict[str, Any]) -> str: """Get the model ID from the payload""" if "id" in payload: @@ -100,6 +219,18 @@ async def process_completion( query_params: dict[str, Any] = None, ) -> Any: """Process a completion request using OpenAI API""" + # Validate messages before sending to API + if "messages" in payload: + self.validate_messages(payload["messages"]) + + # Validate tools if present + if "tools" in payload: + self.validate_tools(payload["tools"]) + + # Validate tool_choice if present + if "tool_choice" in payload: + self.validate_tool_choice(payload["tool_choice"]) + headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", diff --git a/tests/unit_tests/test_openai_provider.py b/tests/unit_tests/test_openai_provider.py index 3bf3416..8e2ac99 100644 --- a/tests/unit_tests/test_openai_provider.py +++ b/tests/unit_tests/test_openai_provider.py @@ -2,6 +2,7 @@ import os from unittest import IsolatedAsyncioTestCase as TestCase from unittest.mock import patch +import pytest from app.services.providers.openai_adapter import OpenAIAdapter from tests.unit_tests.utils.helpers import ( @@ -108,3 +109,259 @@ async def test_chat_completion_streaming(self): expected_model="gpt-4o-mini-2024-07-18", expected_message=OPENAAI_STANDARD_CHAT_COMPLETION_RESPONSE, ) + + def test_openai_adapter_tool_message_validation(self): + """Test that OpenAI adapter validates tool messages correctly""" + from app.services.providers.openai_adapter import OpenAIAdapter + from app.exceptions.exceptions import BaseInvalidRequestException + + adapter = OpenAIAdapter("openai", "https://api.openai.com/v1") + + # Test valid tool message sequence + valid_messages = [ + {"role": "user", "content": "What's the weather like?"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}, + {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} + ] + + # This should not raise an exception + adapter.validate_messages(valid_messages) + + # Test invalid tool message sequence (no preceding tool_calls) + invalid_messages = [ + {"role": "user", "content": "What's the weather like?"}, + {"role": "assistant", "content": "I'll check the weather for you."}, + {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} # This should fail + ] + + # This should raise an exception + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_messages(invalid_messages) + + assert "tool' must be a response to a preceding message with 'tool_calls'" in str(exc_info.value) + + # Test tool message with assistant message that has content but no tool_calls + invalid_messages_2 = [ + {"role": "user", "content": "What's the weather like?"}, + {"role": "assistant", "content": "Let me check that for you."}, + {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} # This should fail + ] + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_messages(invalid_messages_2) + + assert "tool' must be a response to a preceding message with 'tool_calls'" in str(exc_info.value) + + def test_openai_adapter_tools_validation(self): + """Test that OpenAI adapter validates tools correctly""" + from app.services.providers.openai_adapter import OpenAIAdapter + from app.exceptions.exceptions import BaseInvalidRequestException + + adapter = OpenAIAdapter("openai", "https://api.openai.com/v1") + + # Test valid tools + valid_tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state"} + }, + "required": ["location"] + } + } + } + ] + + # This should not raise an exception + adapter.validate_tools(valid_tools) + + # Test invalid tool (wrong type) + invalid_tools_1 = [ + { + "type": "retrieval", # Wrong type + "function": { + "name": "get_weather", + "description": "Get the weather" + } + } + ] + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_tools(invalid_tools_1) + + assert "unsupported type 'retrieval'. Only 'function' type is supported" in str(exc_info.value) + + # Test invalid tool (missing function) + invalid_tools_2 = [ + { + "type": "function" + # Missing function object + } + ] + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_tools(invalid_tools_2) + + assert "must have a 'function' object" in str(exc_info.value) + + # Test invalid tool (missing function name) + invalid_tools_3 = [ + { + "type": "function", + "function": { + "description": "Get the weather" + # Missing name + } + } + ] + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_tools(invalid_tools_3) + + assert "must have a valid 'name' string" in str(exc_info.value) + + # Test empty tools list + adapter.validate_tools([]) + + def test_openai_adapter_tool_choice_validation(self): + """Test that OpenAI adapter validates tool_choice correctly""" + from app.services.providers.openai_adapter import OpenAIAdapter + from app.exceptions.exceptions import BaseInvalidRequestException + + adapter = OpenAIAdapter("openai", "https://api.openai.com/v1") + + # Test valid string tool_choice values + valid_choices = ["none", "auto"] + for choice in valid_choices: + adapter.validate_tool_choice(choice) + + # Test valid object tool_choice + valid_object_choice = { + "type": "function", + "function": {"name": "get_weather"} + } + adapter.validate_tool_choice(valid_object_choice) + + # Test invalid string tool_choice + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_tool_choice("invalid") + + assert "tool_choice must be one of ['none', 'auto'], got 'invalid'" in str(exc_info.value) + + # Test invalid object tool_choice (missing type) + invalid_object_choice_1 = { + "function": {"name": "get_weather"} + # Missing type + } + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_tool_choice(invalid_object_choice_1) + + assert "must have a 'type' field" in str(exc_info.value) + + # Test invalid object tool_choice (function type without function) + invalid_object_choice_2 = { + "type": "function" + # Missing function object + } + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_tool_choice(invalid_object_choice_2) + + assert "must have a 'function' object" in str(exc_info.value) + + # Test invalid object tool_choice (invalid type) + invalid_object_choice_3 = { + "type": "invalid_type" + } + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_tool_choice(invalid_object_choice_3) + + assert "type must be one of ['none', 'auto', 'function'], got 'invalid_type'" in str(exc_info.value) + + # Test None tool_choice + adapter.validate_tool_choice(None) + + + def test_anthropic_adapter_function_calling_validation(self): + """Test that Anthropic adapter validates function calling correctly""" + from app.services.providers.anthropic_adapter import AnthropicAdapter + from app.exceptions.exceptions import BaseInvalidRequestException + + adapter = AnthropicAdapter("anthropic", "https://api.anthropic.com", {}) + + # Test valid tools + valid_tools = [ + { + "function": { + "name": "get_weather", + "description": "Get the weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state"} + }, + "required": ["location"] + } + } + } + ] + + # This should not raise an exception + adapter.validate_tools(valid_tools) + + # Test invalid tool (missing function) + invalid_tools_1 = [ + { + "type": "function" + # Missing function object + } + ] + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_tools(invalid_tools_1) + + assert "must have a 'function' object" in str(exc_info.value) + + # Test invalid tool (missing function name) + invalid_tools_2 = [ + { + "function": { + "description": "Get the weather" + # Missing name + } + } + ] + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_tools(invalid_tools_2) + + assert "must have a valid 'name' string" in str(exc_info.value) + + # Test valid tool message sequence + valid_messages = [ + {"role": "user", "content": "What's the weather like?"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}, + {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} + ] + + # This should not raise an exception + adapter.validate_messages(valid_messages) + + # Test invalid tool message sequence (no preceding tool_calls) + invalid_messages = [ + {"role": "user", "content": "What's the weather like?"}, + {"role": "assistant", "content": "I'll check the weather for you."}, + {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} # This should fail + ] + + with pytest.raises(BaseInvalidRequestException) as exc_info: + adapter.validate_messages(invalid_messages) + + assert "tool' must be a response to a preceding message with 'tool_calls'" in str(exc_info.value) From c4a3418f1cd1be2575fd921b0675523ab8deb451 Mon Sep 17 00:00:00 2001 From: Wenjing Yu Date: Mon, 14 Jul 2025 17:18:29 -0700 Subject: [PATCH 2/2] fix error --- app/services/providers/anthropic_adapter.py | 77 ------ app/services/providers/openai_adapter.py | 128 ---------- tests/unit_tests/test_openai_provider.py | 254 -------------------- 3 files changed, 459 deletions(-) diff --git a/app/services/providers/anthropic_adapter.py b/app/services/providers/anthropic_adapter.py index 2514d64..62eaaad 100644 --- a/app/services/providers/anthropic_adapter.py +++ b/app/services/providers/anthropic_adapter.py @@ -34,77 +34,7 @@ def __init__( def provider_name(self) -> str: return self._provider_name - def validate_tools(self, tools: list[dict[str, Any]]) -> None: - """Validate tools structure for Anthropic API compatibility""" - if not tools: - return - - for i, tool in enumerate(tools): - if not isinstance(tool, dict): - error_msg = f"Tool at index {i} must be a dictionary" - logger.error(f"Anthropic API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - - # Anthropic uses "function" type for tools - if "function" not in tool: - error_msg = f"Tool at index {i} must have a 'function' object" - logger.error(f"Anthropic API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - - function = tool["function"] - if not isinstance(function, dict): - error_msg = f"Function at index {i} must be a dictionary" - logger.error(f"Anthropic API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - - function_name = function.get("name") - if not function_name or not isinstance(function_name, str): - error_msg = f"Function at index {i} must have a valid 'name' string" - logger.error(f"Anthropic API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - def validate_messages(self, messages: list[dict[str, Any]]) -> None: - """Validate message structure for Anthropic API compatibility""" - if not messages: - return - - for i, message in enumerate(messages): - role = message.get("role") - - # Check for tool messages that don't have proper preceding tool_calls - if role == "tool": - # Find the preceding assistant message with tool_calls - has_preceding_tool_calls = False - for j in range(i - 1, -1, -1): - prev_message = messages[j] - if prev_message.get("role") == "assistant": - if "tool_calls" in prev_message: - has_preceding_tool_calls = True - break - elif "content" in prev_message: - # If assistant message has content but no tool_calls, - # it's not a valid preceding message for tool role - break - - if not has_preceding_tool_calls: - error_msg = f"Message at index {i} with role 'tool' must be a response to a preceding message with 'tool_calls'" - logger.error(f"Anthropic API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) @staticmethod def convert_openai_image_content_to_anthropic( @@ -199,13 +129,6 @@ async def process_completion( api_key: str, ) -> Any: """Process a completion request using Anthropic API""" - # Validate tools if present - if "tools" in payload: - self.validate_tools(payload["tools"]) - - # Validate messages for tool role if present - if "messages" in payload: - self.validate_messages(payload["messages"]) headers = { "x-api-key": api_key, diff --git a/app/services/providers/openai_adapter.py b/app/services/providers/openai_adapter.py index 47caf3e..6f78f3c 100644 --- a/app/services/providers/openai_adapter.py +++ b/app/services/providers/openai_adapter.py @@ -28,124 +28,7 @@ def __init__( def provider_name(self) -> str: return self._provider_name - def validate_messages(self, messages: list[dict[str, Any]]) -> None: - """Validate message structure for OpenAI API compatibility""" - if not messages: - return - - for i, message in enumerate(messages): - role = message.get("role") - - # Check for tool messages that don't have proper preceding tool_calls - if role == "tool": - # Find the preceding assistant message with tool_calls - has_preceding_tool_calls = False - for j in range(i - 1, -1, -1): - prev_message = messages[j] - if prev_message.get("role") == "assistant": - if "tool_calls" in prev_message: - has_preceding_tool_calls = True - break - elif "content" in prev_message: - # If assistant message has content but no tool_calls, - # it's not a valid preceding message for tool role - break - - if not has_preceding_tool_calls: - error_msg = f"Message at index {i} with role 'tool' must be a response to a preceding message with 'tool_calls'" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - def validate_tools(self, tools: list[dict[str, Any]]) -> None: - """Validate tools structure for OpenAI API compatibility""" - if not tools: - return - - for i, tool in enumerate(tools): - if not isinstance(tool, dict): - error_msg = f"Tool at index {i} must be a dictionary" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - - tool_type = tool.get("type") - if tool_type != "function": - error_msg = f"Tool at index {i} has unsupported type '{tool_type}'. Only 'function' type is supported" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - - function = tool.get("function") - if not function or not isinstance(function, dict): - error_msg = f"Tool at index {i} must have a 'function' object" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - - function_name = function.get("name") - if not function_name or not isinstance(function_name, str): - error_msg = f"Function at index {i} must have a valid 'name' string" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - - def validate_tool_choice(self, tool_choice: Any) -> None: - """Validate tool_choice parameter for OpenAI API compatibility""" - if tool_choice is None: - return - - if isinstance(tool_choice, str): - valid_choices = ["none", "auto"] - if tool_choice not in valid_choices: - error_msg = f"tool_choice must be one of {valid_choices}, got '{tool_choice}'" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - elif isinstance(tool_choice, dict): - if "type" not in tool_choice: - error_msg = "tool_choice object must have a 'type' field" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - - tool_choice_type = tool_choice.get("type") - if tool_choice_type == "function": - if "function" not in tool_choice: - error_msg = "tool_choice with type 'function' must have a 'function' object" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - elif tool_choice_type not in ["none", "auto"]: - error_msg = f"tool_choice type must be one of ['none', 'auto', 'function'], got '{tool_choice_type}'" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) - else: - error_msg = f"tool_choice must be a string or object, got {type(tool_choice).__name__}" - logger.error(f"OpenAI API validation error: {error_msg}") - raise BaseInvalidRequestException( - provider_name=self.provider_name, - error=ValueError(error_msg) - ) def get_model_id(self, payload: dict[str, Any]) -> str: """Get the model ID from the payload""" @@ -219,17 +102,6 @@ async def process_completion( query_params: dict[str, Any] = None, ) -> Any: """Process a completion request using OpenAI API""" - # Validate messages before sending to API - if "messages" in payload: - self.validate_messages(payload["messages"]) - - # Validate tools if present - if "tools" in payload: - self.validate_tools(payload["tools"]) - - # Validate tool_choice if present - if "tool_choice" in payload: - self.validate_tool_choice(payload["tool_choice"]) headers = { "Authorization": f"Bearer {api_key}", diff --git a/tests/unit_tests/test_openai_provider.py b/tests/unit_tests/test_openai_provider.py index 8e2ac99..c05c42e 100644 --- a/tests/unit_tests/test_openai_provider.py +++ b/tests/unit_tests/test_openai_provider.py @@ -110,258 +110,4 @@ async def test_chat_completion_streaming(self): expected_message=OPENAAI_STANDARD_CHAT_COMPLETION_RESPONSE, ) - def test_openai_adapter_tool_message_validation(self): - """Test that OpenAI adapter validates tool messages correctly""" - from app.services.providers.openai_adapter import OpenAIAdapter - from app.exceptions.exceptions import BaseInvalidRequestException - - adapter = OpenAIAdapter("openai", "https://api.openai.com/v1") - - # Test valid tool message sequence - valid_messages = [ - {"role": "user", "content": "What's the weather like?"}, - {"role": "assistant", "content": None, "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}, - {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} - ] - - # This should not raise an exception - adapter.validate_messages(valid_messages) - - # Test invalid tool message sequence (no preceding tool_calls) - invalid_messages = [ - {"role": "user", "content": "What's the weather like?"}, - {"role": "assistant", "content": "I'll check the weather for you."}, - {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} # This should fail - ] - - # This should raise an exception - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_messages(invalid_messages) - - assert "tool' must be a response to a preceding message with 'tool_calls'" in str(exc_info.value) - - # Test tool message with assistant message that has content but no tool_calls - invalid_messages_2 = [ - {"role": "user", "content": "What's the weather like?"}, - {"role": "assistant", "content": "Let me check that for you."}, - {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} # This should fail - ] - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_messages(invalid_messages_2) - - assert "tool' must be a response to a preceding message with 'tool_calls'" in str(exc_info.value) - - def test_openai_adapter_tools_validation(self): - """Test that OpenAI adapter validates tools correctly""" - from app.services.providers.openai_adapter import OpenAIAdapter - from app.exceptions.exceptions import BaseInvalidRequestException - - adapter = OpenAIAdapter("openai", "https://api.openai.com/v1") - - # Test valid tools - valid_tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the weather for a location", - "parameters": { - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state"} - }, - "required": ["location"] - } - } - } - ] - - # This should not raise an exception - adapter.validate_tools(valid_tools) - - # Test invalid tool (wrong type) - invalid_tools_1 = [ - { - "type": "retrieval", # Wrong type - "function": { - "name": "get_weather", - "description": "Get the weather" - } - } - ] - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_tools(invalid_tools_1) - - assert "unsupported type 'retrieval'. Only 'function' type is supported" in str(exc_info.value) - - # Test invalid tool (missing function) - invalid_tools_2 = [ - { - "type": "function" - # Missing function object - } - ] - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_tools(invalid_tools_2) - - assert "must have a 'function' object" in str(exc_info.value) - - # Test invalid tool (missing function name) - invalid_tools_3 = [ - { - "type": "function", - "function": { - "description": "Get the weather" - # Missing name - } - } - ] - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_tools(invalid_tools_3) - - assert "must have a valid 'name' string" in str(exc_info.value) - - # Test empty tools list - adapter.validate_tools([]) - def test_openai_adapter_tool_choice_validation(self): - """Test that OpenAI adapter validates tool_choice correctly""" - from app.services.providers.openai_adapter import OpenAIAdapter - from app.exceptions.exceptions import BaseInvalidRequestException - - adapter = OpenAIAdapter("openai", "https://api.openai.com/v1") - - # Test valid string tool_choice values - valid_choices = ["none", "auto"] - for choice in valid_choices: - adapter.validate_tool_choice(choice) - - # Test valid object tool_choice - valid_object_choice = { - "type": "function", - "function": {"name": "get_weather"} - } - adapter.validate_tool_choice(valid_object_choice) - - # Test invalid string tool_choice - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_tool_choice("invalid") - - assert "tool_choice must be one of ['none', 'auto'], got 'invalid'" in str(exc_info.value) - - # Test invalid object tool_choice (missing type) - invalid_object_choice_1 = { - "function": {"name": "get_weather"} - # Missing type - } - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_tool_choice(invalid_object_choice_1) - - assert "must have a 'type' field" in str(exc_info.value) - - # Test invalid object tool_choice (function type without function) - invalid_object_choice_2 = { - "type": "function" - # Missing function object - } - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_tool_choice(invalid_object_choice_2) - - assert "must have a 'function' object" in str(exc_info.value) - - # Test invalid object tool_choice (invalid type) - invalid_object_choice_3 = { - "type": "invalid_type" - } - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_tool_choice(invalid_object_choice_3) - - assert "type must be one of ['none', 'auto', 'function'], got 'invalid_type'" in str(exc_info.value) - - # Test None tool_choice - adapter.validate_tool_choice(None) - - - def test_anthropic_adapter_function_calling_validation(self): - """Test that Anthropic adapter validates function calling correctly""" - from app.services.providers.anthropic_adapter import AnthropicAdapter - from app.exceptions.exceptions import BaseInvalidRequestException - - adapter = AnthropicAdapter("anthropic", "https://api.anthropic.com", {}) - - # Test valid tools - valid_tools = [ - { - "function": { - "name": "get_weather", - "description": "Get the weather for a location", - "parameters": { - "type": "object", - "properties": { - "location": {"type": "string", "description": "The city and state"} - }, - "required": ["location"] - } - } - } - ] - - # This should not raise an exception - adapter.validate_tools(valid_tools) - - # Test invalid tool (missing function) - invalid_tools_1 = [ - { - "type": "function" - # Missing function object - } - ] - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_tools(invalid_tools_1) - - assert "must have a 'function' object" in str(exc_info.value) - - # Test invalid tool (missing function name) - invalid_tools_2 = [ - { - "function": { - "description": "Get the weather" - # Missing name - } - } - ] - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_tools(invalid_tools_2) - - assert "must have a valid 'name' string" in str(exc_info.value) - - # Test valid tool message sequence - valid_messages = [ - {"role": "user", "content": "What's the weather like?"}, - {"role": "assistant", "content": None, "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{}"}}]}, - {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} - ] - - # This should not raise an exception - adapter.validate_messages(valid_messages) - - # Test invalid tool message sequence (no preceding tool_calls) - invalid_messages = [ - {"role": "user", "content": "What's the weather like?"}, - {"role": "assistant", "content": "I'll check the weather for you."}, - {"role": "tool", "content": "Sunny", "tool_call_id": "call_1"} # This should fail - ] - - with pytest.raises(BaseInvalidRequestException) as exc_info: - adapter.validate_messages(invalid_messages) - - assert "tool' must be a response to a preceding message with 'tool_calls'" in str(exc_info.value)