diff --git a/app/services/providers/anthropic_adapter.py b/app/services/providers/anthropic_adapter.py
index 4533d31..62eaaad 100644
--- a/app/services/providers/anthropic_adapter.py
+++ b/app/services/providers/anthropic_adapter.py
@@ -9,6 +9,7 @@
 from app.core.logger import get_logger
 from app.exceptions.exceptions import ProviderAPIException, InvalidCompletionRequestException
+from app.exceptions.exceptions import BaseInvalidRequestException
 
 from .base import ProviderAdapter
 
 
@@ -33,6 +34,8 @@ def __init__(
 
     def provider_name(self) -> str:
         return self._provider_name
 
+
+    @staticmethod
     def convert_openai_image_content_to_anthropic(
         msg: dict[str, Any],
@@ -126,6 +129,7 @@ async def process_completion(
         api_key: str,
     ) -> Any:
         """Process a completion request using Anthropic API"""
+
         headers = {
             "x-api-key": api_key,
             "Content-Type": "application/json",
@@ -161,12 +165,19 @@ async def process_completion(
                     anthropic_messages.append({"role": "user", "content": content})
                 elif role == "assistant":
                     anthropic_messages.append({"role": "assistant", "content": content})
+                elif role == "tool":
+                    # FIXME(review): Anthropic's Messages API has no "tool" role; a tool result must be sent as a "user" message whose content is a list of {"type": "tool_result", "tool_use_id": ...} blocks — wire up the originating tool_use id here
+                    anthropic_messages.append({"role": "user", "content": content})
 
             # Add system message if present
             if system_message:
                 anthropic_payload["system"] = system_message
 
             anthropic_payload["messages"] = anthropic_messages
+
+            # Convert OpenAI function-tool definitions to Anthropic's {"name", "description", "input_schema"} shape
+            if "tools" in payload:
+                anthropic_payload["tools"] = [{"name": t["function"]["name"], "description": t["function"].get("description", ""), "input_schema": t["function"].get("parameters", {"type": "object", "properties": {}})} for t in payload["tools"] if t.get("type") == "function"]
         else:
             # Handle regular completion (legacy format)
             anthropic_payload["prompt"] = f"Human: {payload['prompt']}\n\nAssistant: "
@@ -248,7 +259,9 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
                             )
 
                         elif event_type == "content_block_delta":
-                            delta_content = data.get("delta", {}).get("text", "")
+                            delta = data.get("delta", {})
+                            delta_content = delta.get("text", "")
+
                             if delta_content:
                                 openai_chunk = {
                                     "id": request_id,
@@ -263,6 +276,31 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
                                         }
                                     ],
                                 }
+                            elif "tool_use" in delta:
+                                # FIXME(review): Anthropic streams tool calls via content_block_start (type "tool_use") followed by input_json_delta events carrying "partial_json"; a "tool_use" key never appears inside content_block_delta, so this branch is dead as written — rework against the streaming docs
+                                tool_use = delta["tool_use"]
+                                openai_chunk = {
+                                    "id": request_id,
+                                    "object": "chat.completion.chunk",
+                                    "created": int(time.time()),
+                                    "model": model_name,
+                                    "choices": [
+                                        {
+                                            "index": 0,
+                                            "delta": {
+                                                "tool_calls": [{
+                                                    "id": tool_use.get("id", "call_1"),
+                                                    "type": "function",
+                                                    "function": {
+                                                        "name": tool_use.get("name", ""),
+                                                        "arguments": json.dumps(tool_use.get("input", {}))
+                                                    }
+                                                }]
+                                            },
+                                            "finish_reason": None,
+                                        }
+                                    ],
+                                }
 
                         # Capture Output Tokens & Finish Reason from message_delta
                         elif event_type == "message_delta":
@@ -365,14 +403,40 @@ async def _process_regular_response(
         # Messages API response
         content = anthropic_response.get("content", [])
         text_content = ""
+        tool_calls = []
 
-        # Extract text from content blocks
+        # Extract text and tool calls from content blocks
         for block in content:
             if block.get("type") == "text":
                 text_content += block.get("text", "")
+            elif block.get("type") == "tool_use":
+                # Convert Anthropic tool_use to OpenAI tool_calls format
+                tool_use = block  # id/name/input sit directly on the tool_use content block, not under a nested "tool_use" key
+                tool_calls.append({
+                    "id": tool_use.get("id", f"call_{len(tool_calls)}"),
+                    "type": "function",
+                    "function": {
+                        "name": tool_use.get("name", ""),
+                        "arguments": json.dumps(tool_use.get("input", {}))
+                    }
+                })
 
         input_tokens = anthropic_response.get("usage", {}).get("input_tokens", 0)
         output_tokens = anthropic_response.get("usage", {}).get("output_tokens", 0)
+
+        # Determine finish reason
+        finish_reason = "stop"
+        if tool_calls:
+            finish_reason = "tool_calls"
+
+        message = {
+            "role": "assistant",
+            "content": text_content if text_content else None,
+        }
+
+        if tool_calls:
+            message["tool_calls"] = tool_calls
+
         return {
             "id": completion_id,
             "object": "chat.completion",
@@ -381,11 +445,8 @@ async def _process_regular_response(
             "choices": [
                 {
                     "index": 0,
-                    "message": {
-                        "role": "assistant",
-                        "content": text_content,
-                    },
-                    "finish_reason": "stop",
+                    "message": message,
+                    "finish_reason": finish_reason,
                 }
             ],
             "usage": {
diff --git a/app/services/providers/openai_adapter.py b/app/services/providers/openai_adapter.py
index 933778b..6f78f3c 100644
--- a/app/services/providers/openai_adapter.py
+++ b/app/services/providers/openai_adapter.py
@@ -28,6 +28,8 @@ def __init__(
     def provider_name(self) -> str:
         return self._provider_name
 
+
+
     def get_model_id(self, payload: dict[str, Any]) -> str:
         """Get the model ID from the payload"""
         if "id" in payload:
@@ -100,6 +102,7 @@ async def process_completion(
         query_params: dict[str, Any] = None,
     ) -> Any:
         """Process a completion request using OpenAI API"""
+
         headers = {
             "Authorization": f"Bearer {api_key}",
             "Content-Type": "application/json",
diff --git a/tests/unit_tests/test_openai_provider.py b/tests/unit_tests/test_openai_provider.py
index 3bf3416..c05c42e 100644
--- a/tests/unit_tests/test_openai_provider.py
+++ b/tests/unit_tests/test_openai_provider.py
@@ -2,6 +2,7 @@ import os
 from unittest import IsolatedAsyncioTestCase as TestCase
 from unittest.mock import patch
+import pytest
 
 from app.services.providers.openai_adapter import OpenAIAdapter
 from tests.unit_tests.utils.helpers import (
@@ -108,3 +109,5 @@
             expected_model="gpt-4o-mini-2024-07-18",
             expected_message=OPENAAI_STANDARD_CHAT_COMPLETION_RESPONSE,
         )
+
+