Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 1 addition & 41 deletions app/api/schemas/provider_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,6 @@

logger = get_logger(name="provider_key")

# Constants for API key masking
API_KEY_MASK_PREFIX_LENGTH = 2
API_KEY_MASK_SUFFIX_LENGTH = 4
# Minimum length to apply the full prefix + suffix mask (e.g., pr******fix)
# This means if length is > (PREFIX + SUFFIX), we can apply the full rule.
MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC = (
API_KEY_MASK_PREFIX_LENGTH + API_KEY_MASK_SUFFIX_LENGTH
)


# Helper function for masking API keys
def _mask_api_key_value(value: str | None) -> str | None:
if not value:
return None

length = len(value)

if length == 0:
return ""

# If key is too short for any meaningful prefix/suffix masking
if length <= API_KEY_MASK_PREFIX_LENGTH:
return "*" * length

# If key is long enough for prefix, but not for prefix + suffix
# e.g., length is 3, 4, 5, 6. For these, show prefix and mask the rest.
if length <= MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC:
return value[:API_KEY_MASK_PREFIX_LENGTH] + "*" * (
length - API_KEY_MASK_PREFIX_LENGTH
)

# If key is long enough for the full prefix + ... + suffix mask
# number of asterisks = length - prefix_length - suffix_length
num_asterisks = length - API_KEY_MASK_PREFIX_LENGTH - API_KEY_MASK_SUFFIX_LENGTH
return (
value[:API_KEY_MASK_PREFIX_LENGTH]
+ "*" * num_asterisks
+ value[-API_KEY_MASK_SUFFIX_LENGTH:]
)


class ProviderKeyBase(BaseModel):
provider_name: str = Field(..., min_length=1)
Expand Down Expand Up @@ -109,7 +69,7 @@ def api_key(self) -> str | None:
api_key, _ = provider_adapter_cls.deserialize_api_key_config(
decrypted_value
)
return _mask_api_key_value(api_key)
return provider_adapter_cls.mask_api_key(api_key)
except Exception as e:
logger.error(
f"Error deserializing API key for provider {self.provider_name}: {e}"
Expand Down
4 changes: 4 additions & 0 deletions app/services/providers/adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .perplexity_adapter import PerplexityAdapter
from .tensorblock_adapter import TensorblockAdapter
from .zhipu_adapter import ZhipuAdapter
from .vertex_adapter import VertexAdapter


class ProviderAdapterFactory:
Expand Down Expand Up @@ -185,6 +186,9 @@ class ProviderAdapterFactory:
"bedrock": {
"adapter": BedrockAdapter,
},
"vertex": {
"adapter": VertexAdapter,
},
"customized": {
"adapter": OpenAIAdapter,
},
Expand Down
80 changes: 47 additions & 33 deletions app/services/providers/anthropic_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Any
from typing import Any, Callable

import aiohttp

Expand All @@ -16,6 +16,8 @@

ANTHROPIC_DEFAULT_MAX_TOKENS = 4096

logger = get_logger(name="anthropic_adapter")


class AnthropicAdapter(ProviderAdapter):
"""Adapter for Anthropic API"""
Expand Down Expand Up @@ -118,22 +120,10 @@ async def list_models(self, api_key: str) -> list[str]:
self.cache_models(api_key, self._base_url, models)

return models

async def process_completion(
self,
endpoint: str,
payload: dict[str, Any],
api_key: str,
) -> Any:
"""Process a completion request using Anthropic API"""
headers = {
"x-api-key": api_key,
"Content-Type": "application/json",
"anthropic-version": "2023-06-01",
}

# Convert OpenAI format to Anthropic format
streaming = payload.get("stream", False)

@staticmethod
def convert_openai_payload_to_anthropic(payload: dict[str, Any]) -> dict[str, Any]:
"""Convert Anthropic completion payload to OpenAI format"""
anthropic_payload = {
"model": payload["model"],
"max_tokens": payload.get("max_completion_tokens", payload.get("max_tokens", ANTHROPIC_DEFAULT_MAX_TOKENS)),
Expand All @@ -150,7 +140,7 @@ async def process_completion(
for msg in payload["messages"]:
role = msg["role"]
content = msg["content"]
content = self.convert_openai_content_to_anthropic(content)
content = AnthropicAdapter.convert_openai_content_to_anthropic(content)

if role == "system":
# Anthropic requires a system message to be string
Expand All @@ -171,6 +161,25 @@ async def process_completion(
# Handle regular completion (legacy format)
anthropic_payload["prompt"] = f"Human: {payload['prompt']}\n\nAssistant: "

return anthropic_payload

async def process_completion(
self,
endpoint: str,
payload: dict[str, Any],
api_key: str,
) -> Any:
"""Process a completion request using Anthropic API"""
headers = {
"x-api-key": api_key,
"Content-Type": "application/json",
"anthropic-version": "2023-06-01",
}

streaming = payload.get("stream", False)
# Convert OpenAI format to Anthropic format
anthropic_payload = self.convert_openai_payload_to_anthropic(payload)

# Choose the appropriate API endpoint - using ternary operator
api_endpoint = "messages" if "messages" in anthropic_payload else "complete"

Expand All @@ -179,17 +188,18 @@ async def process_completion(
# Handle streaming requests
if streaming and "messages" in anthropic_payload:
anthropic_payload["stream"] = True
return await self._stream_anthropic_response(
return await self.stream_anthropic_response(
url, headers, anthropic_payload, payload["model"]
)
else:
# For non-streaming, use the regular approach
return await self._process_regular_response(
return await self.process_regular_response(
url, headers, anthropic_payload, payload["model"]
)

async def _stream_anthropic_response(
self, url, headers, anthropic_payload, model_name
@staticmethod
async def stream_anthropic_response(
url, headers, anthropic_payload, model_name, error_handler: Callable[[str, int], Any] | None = None
):
"""Handle streaming response from Anthropic API, including usage data."""

Expand All @@ -206,12 +216,15 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
):
if response.status != HTTPStatus.OK:
error_text = await response.text()
logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}")
raise ProviderAPIException(
provider_name=self.provider_name,
error_code=response.status,
error_message=error_text
)
if error_handler:
error_handler(error_text, response.status)
else:
logger.error(f"Completion API error for {error_text}")
raise ProviderAPIException(
provider_name="anthropic",
error_code=response.status,
error_message=error_text,
)

buffer = ""
async for line_bytes in response.content:
Expand Down Expand Up @@ -337,8 +350,9 @@ async def stream_response() -> AsyncGenerator[bytes, None]:

return stream_response()

async def _process_regular_response(
self, url, headers, anthropic_payload, model_name
@staticmethod
async def process_regular_response(
url: str, headers: dict[str, str], anthropic_payload: dict[str, Any], model_name: str, error_handler: Callable[[str, int], Any] | None = None
):
"""Handle regular (non-streaming) response from Anthropic API"""
# Single with statement for multiple contexts
Expand All @@ -348,11 +362,11 @@ async def _process_regular_response(
):
if response.status != HTTPStatus.OK:
error_text = await response.text()
logger.error(f"Completion API error for {self.provider_name}: {error_text}")
logger.error(f"Completion API error for {error_text}")
raise ProviderAPIException(
provider_name=self.provider_name,
provider_name="anthropic",
error_code=response.status,
error_message=error_text
error_message=error_text,
)

anthropic_response = await response.json()
Expand Down
41 changes: 40 additions & 1 deletion app/services/providers/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import json
import time
from abc import ABC, abstractmethod
from typing import Any, ClassVar

# Constants for API key masking
API_KEY_MASK_PREFIX_LENGTH = 2
API_KEY_MASK_SUFFIX_LENGTH = 4
# Minimum length to apply the full prefix + suffix mask (e.g., pr******fix)
# This means if length is > (PREFIX + SUFFIX), we can apply the full rule.
MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC = (
API_KEY_MASK_PREFIX_LENGTH + API_KEY_MASK_SUFFIX_LENGTH
)


class ProviderAdapter(ABC):
"""Base class for all provider adapters"""
Expand Down Expand Up @@ -66,6 +74,37 @@ def deserialize_api_key_config(serialized_api_key_config: str) -> tuple[str, dic
def mask_config(config: dict[str, Any]) -> dict[str, Any]:
"""Mask the config for the given provider"""
return config

@staticmethod
def mask_api_key(api_key: str) -> str:
"""Mask the API key for the given provider"""
if not api_key:
return None

length = len(api_key)

if length == 0:
return ""

# If key is too short for any meaningful prefix/suffix masking
if length <= API_KEY_MASK_PREFIX_LENGTH:
return "*" * length

# If key is long enough for prefix, but not for prefix + suffix
# e.g., length is 3, 4, 5, 6. For these, show prefix and mask the rest.
if length <= MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC:
return api_key[:API_KEY_MASK_PREFIX_LENGTH] + "*" * (
length - API_KEY_MASK_PREFIX_LENGTH
)

# If key is long enough for the full prefix + ... + suffix mask
# number of asterisks = length - prefix_length - suffix_length
num_asterisks = length - API_KEY_MASK_PREFIX_LENGTH - API_KEY_MASK_SUFFIX_LENGTH
return (
api_key[:API_KEY_MASK_PREFIX_LENGTH]
+ "*" * num_asterisks
+ api_key[-API_KEY_MASK_SUFFIX_LENGTH:]
)

def cache_models(
self, api_key: str, base_url: str | None, models: list[str]
Expand Down
Loading
Loading