Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 68 additions & 7 deletions app/services/providers/anthropic_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from app.core.logger import get_logger
from app.exceptions.exceptions import ProviderAPIException, InvalidCompletionRequestException
from app.exceptions.exceptions import BaseInvalidRequestException

from .base import ProviderAdapter

Expand All @@ -33,6 +34,8 @@ def __init__(
def provider_name(self) -> str:
return self._provider_name



@staticmethod
def convert_openai_image_content_to_anthropic(
msg: dict[str, Any],
Expand Down Expand Up @@ -126,6 +129,7 @@ async def process_completion(
api_key: str,
) -> Any:
"""Process a completion request using Anthropic API"""

headers = {
"x-api-key": api_key,
"Content-Type": "application/json",
Expand Down Expand Up @@ -161,12 +165,19 @@ async def process_completion(
anthropic_messages.append({"role": "user", "content": content})
elif role == "assistant":
anthropic_messages.append({"role": "assistant", "content": content})
elif role == "tool":
# Anthropic uses "tool" role for tool responses
anthropic_messages.append({"role": "tool", "content": content})

# Add system message if present
if system_message:
anthropic_payload["system"] = system_message

anthropic_payload["messages"] = anthropic_messages

# Add tools if present
if "tools" in payload:
anthropic_payload["tools"] = payload["tools"]
else:
# Handle regular completion (legacy format)
anthropic_payload["prompt"] = f"Human: {payload['prompt']}\n\nAssistant: "
Expand Down Expand Up @@ -248,7 +259,9 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
)

elif event_type == "content_block_delta":
delta_content = data.get("delta", {}).get("text", "")
delta = data.get("delta", {})
delta_content = delta.get("text", "")

if delta_content:
openai_chunk = {
"id": request_id,
Expand All @@ -263,6 +276,31 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
}
],
}
elif "tool_use" in delta:
# Handle tool_use delta
tool_use = delta["tool_use"]
openai_chunk = {
"id": request_id,
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model_name,
"choices": [
{
"index": 0,
"delta": {
"tool_calls": [{
"id": tool_use.get("id", "call_1"),
"type": "function",
"function": {
"name": tool_use.get("name", ""),
"arguments": tool_use.get("input", "{}")
}
}]
},
"finish_reason": None,
}
],
}

# Capture Output Tokens & Finish Reason from message_delta
elif event_type == "message_delta":
Expand Down Expand Up @@ -365,14 +403,40 @@ async def _process_regular_response(
# Messages API response
content = anthropic_response.get("content", [])
text_content = ""
tool_calls = []

# Extract text from content blocks
# Extract text and tool calls from content blocks
for block in content:
if block.get("type") == "text":
text_content += block.get("text", "")
elif block.get("type") == "tool_use":
# Convert Anthropic tool_use to OpenAI tool_calls format
tool_use = block.get("tool_use", {})
tool_calls.append({
"id": tool_use.get("id", f"call_{len(tool_calls)}"),
"type": "function",
"function": {
"name": tool_use.get("name", ""),
"arguments": tool_use.get("input", "{}")
}
})

input_tokens = anthropic_response.get("usage", {}).get("input_tokens", 0)
output_tokens = anthropic_response.get("usage", {}).get("output_tokens", 0)

# Determine finish reason
finish_reason = "stop"
if tool_calls:
finish_reason = "tool_calls"

message = {
"role": "assistant",
"content": text_content if text_content else None,
}

if tool_calls:
message["tool_calls"] = tool_calls

return {
"id": completion_id,
"object": "chat.completion",
Expand All @@ -381,11 +445,8 @@ async def _process_regular_response(
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": text_content,
},
"finish_reason": "stop",
"message": message,
"finish_reason": finish_reason,
}
],
"usage": {
Expand Down
3 changes: 3 additions & 0 deletions app/services/providers/openai_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def __init__(
def provider_name(self) -> str:
return self._provider_name



def get_model_id(self, payload: dict[str, Any]) -> str:
"""Get the model ID from the payload"""
if "id" in payload:
Expand Down Expand Up @@ -100,6 +102,7 @@ async def process_completion(
query_params: dict[str, Any] = None,
) -> Any:
"""Process a completion request using OpenAI API"""

headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
Expand Down
3 changes: 3 additions & 0 deletions tests/unit_tests/test_openai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from unittest import IsolatedAsyncioTestCase as TestCase
from unittest.mock import patch
import pytest

from app.services.providers.openai_adapter import OpenAIAdapter
from tests.unit_tests.utils.helpers import (
Expand Down Expand Up @@ -108,3 +109,5 @@ async def test_chat_completion_streaming(self):
expected_model="gpt-4o-mini-2024-07-18",
expected_message=OPENAAI_STANDARD_CHAT_COMPLETION_RESPONSE,
)


Loading