Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions app/services/providers/anthropic_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ProviderAPIException,
InvalidCompletionRequestException,
)
from app.utils.translator import download_image_url

from .base import ProviderAdapter

Expand Down Expand Up @@ -57,11 +58,22 @@ def format_anthropic_usage(usage_data: dict[str, Any]) -> dict[str, Any]:
}

@staticmethod
def convert_openai_image_content_to_anthropic(
msg: dict[str, Any],
async def convert_openai_image_content_to_anthropic(
msg: dict[str, Any], allow_url_download: bool = False
) -> dict[str, Any]:
"""Convert OpenAI image content to Anthropic image content"""
data_url = msg["image_url"]["url"]
if allow_url_download:
try:
data_url = await download_image_url(logger, data_url)
except Exception as e:
logger.exception(f"Error downloading image: {e}")
raise ProviderAPIException(
provider_name="anthropic",
error_code=400,
error_message=f"Error downloading image: {e}",
)

if data_url.startswith("data:"):
# Extract media type and base64 data
parts = data_url.split(",", 1)
Expand Down Expand Up @@ -115,8 +127,9 @@ def translate_anthropic_content_to_openai(


@staticmethod
def convert_openai_content_to_anthropic(
async def convert_openai_content_to_anthropic(
content: list[dict[str, Any]] | str | None,
allow_url_download: bool = False,
) -> list[dict[str, Any]] | str:
"""Convert OpenAI content model to Anthropic content model"""
if content is None:
Expand All @@ -138,7 +151,7 @@ def convert_openai_content_to_anthropic(
result.append({"type": "text", "text": msg.get("text", "")})
elif _type == "image_url":
result.append(
AnthropicAdapter.convert_openai_image_content_to_anthropic(msg)
await AnthropicAdapter.convert_openai_image_content_to_anthropic(msg, allow_url_download=allow_url_download)
)
else:
error_message = f"{_type} is not supported"
Expand Down Expand Up @@ -189,7 +202,7 @@ async def list_models(self, api_key: str) -> list[str]:
return models

@staticmethod
def convert_openai_payload_to_anthropic(payload: dict[str, Any]) -> dict[str, Any]:
async def convert_openai_payload_to_anthropic(payload: dict[str, Any], allow_url_download: bool = False) -> dict[str, Any]:
"""Convert OpenAI completion payload to Anthropic format"""
anthropic_payload = {
"model": payload["model"],
Expand Down Expand Up @@ -319,7 +332,7 @@ def convert_openai_payload_to_anthropic(payload: dict[str, Any]) -> dict[str, An
anthropic_content = content
else:
anthropic_content = (
AnthropicAdapter.convert_openai_content_to_anthropic(content)
await AnthropicAdapter.convert_openai_content_to_anthropic(content, allow_url_download=allow_url_download)
)

anthropic_message = {"role": role, "content": anthropic_content}
Expand Down Expand Up @@ -399,7 +412,7 @@ async def process_completion(

streaming = payload.get("stream", False)
# Convert OpenAI format to Anthropic format
anthropic_payload = self.convert_openai_payload_to_anthropic(payload)
anthropic_payload = await self.convert_openai_payload_to_anthropic(payload)

# Choose the appropriate API endpoint - using ternary operator
api_endpoint = "messages" if "messages" in anthropic_payload else "complete"
Expand Down
4 changes: 1 addition & 3 deletions app/services/providers/gemini_openai_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ def __init__(
if not base_url.endswith("/openai"):
base_url = f"{base_url}/openai"

logger.debug(
"Initialised GeminiOpenAIAdapter with base_url=%s", base_url
)
logger.debug(f"Initialised GeminiOpenAIAdapter with base_url={base_url}")

super().__init__(provider_name, base_url, config=config or {})
1 change: 0 additions & 1 deletion app/services/providers/google_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
and for reference while migrating any bespoke features that haven’t yet been
replicated in the new adapter. **It will be removed in a future release.**
"""
import asyncio
import json
import os
import time
Expand Down
5 changes: 3 additions & 2 deletions app/services/providers/vertex_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ async def process_completion(self, endpoint: str, payload: dict[str, Any], api_k

streaming = payload.get("stream", False)
model_name = payload["model"]
anthropic_payload = AnthropicAdapter.convert_openai_payload_to_anthropic(payload)
anthropic_payload = await AnthropicAdapter.convert_openai_payload_to_anthropic(payload, allow_url_download=True)

# vertex specific payload
anthropic_payload["anthropic_version"] = "vertex-2023-10-16"
Expand All @@ -226,13 +226,14 @@ async def process_completion(self, endpoint: str, payload: dict[str, Any], api_k
logger.debug(f"Vertex API request - model: {model_name}, streaming: {streaming}, publisher: {self.publisher}, location: {self.location}")

def error_handler(error_text: str, http_status: int):
logger.error(f"Vertex API error - code: {http_status}, message: {error_text}")
try:
error_json = json.loads(error_text)
error_message = error_json.get("error", {}).get("message", "Unknown error")
error_code = error_json.get("error", {}).get("code", http_status)
raise ProviderAPIException("Vertex", error_code, error_message)
except Exception:
raise ProviderAPIException("Vertex", http_status, error_text)
raise ProviderAPIException("Vertex", error_code, error_message)

if streaming:
# https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamRawPredict
Expand Down
31 changes: 31 additions & 0 deletions app/utils/translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import base64
import aiohttp
from http import HTTPStatus

async def download_image_url(logger, image_url: str) -> str:
"""
Download an image from a URL and return the base64 encoded string
"""

# if the image url is a data url, return it as is
if image_url.startswith("data:"):
return image_url

async with aiohttp.ClientSession() as session:
async with session.head(image_url) as response:
if response.status != HTTPStatus.OK:
error_text = await response.text()
log_error_msg = f"Failed to fetch file metadata from URL: {error_text}"
logger.error(log_error_msg)
raise RuntimeError(log_error_msg)

mime_type = response.headers.get("Content-Type", "")
file_size = int(response.headers.get("Content-Length", 0))
if file_size > 10 * 1024 * 1024:
log_error_msg = f"Image file size is too large: {file_size} bytes"
logger.error(log_error_msg)
raise RuntimeError(log_error_msg)

async with session.get(image_url) as response:
# return format is data:mime_type;base64,base64_data
return f"data:{mime_type};base64,{base64.b64encode(await response.read()).decode('utf-8')}"
Loading