From 3ed4ee99e80c34bfbda226c144af4984f6bd12de Mon Sep 17 00:00:00 2001 From: Lingtong Lu Date: Mon, 14 Jul 2025 12:57:45 -0700 Subject: [PATCH 01/10] Update error handling and log-ability (#9) --- app/api/schemas/openai.py | 33 +++- app/exceptions/exceptions.py | 81 ++++++++- app/main.py | 63 ++++++- app/services/provider_service.py | 62 +++---- app/services/providers/anthropic_adapter.py | 71 ++++---- app/services/providers/azure_adapter.py | 34 +++- app/services/providers/bedrock_adapter.py | 124 +++++++++++--- app/services/providers/cohere_adapter.py | 43 ++++- app/services/providers/google_adapter.py | 161 +++++++++++++----- app/services/providers/openai_adapter.py | 71 +++++--- .../test_provider_service_images.py | 2 +- 11 files changed, 570 insertions(+), 175 deletions(-) diff --git a/app/api/schemas/openai.py b/app/api/schemas/openai.py index a9061a9..60d8d20 100644 --- a/app/api/schemas/openai.py +++ b/app/api/schemas/openai.py @@ -2,6 +2,11 @@ from pydantic import BaseModel +from app.core.logger import get_logger +from app.exceptions.exceptions import InvalidCompletionRequestException + +logger = get_logger(name="openai_schemas") + class OpenAIContentImageUrlModel(BaseModel): url: str @@ -21,17 +26,35 @@ class OpenAIContentModel(BaseModel): def __init__(self, **data: Any): super().__init__(**data) if self.type not in ["text", "image_url", "input_audio"]: - raise ValueError( - f"Invalid type: {self.type}. Must be one of: text, image_url, input_audio" + error_message = f"Invalid type: {self.type}. Must be one of: text, image_url, input_audio" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="openai", + error=ValueError(error_message) ) # Validate that the appropriate field is set based on type if self.type == "text" and self.text is None: - raise ValueError("text field must be set when type is 'text'") + error_message = "text field must be set when type is 'text'" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="openai", + error=ValueError(error_message) + ) if self.type == "image_url" and self.image_url is None: - raise ValueError("image_url field must be set when type is 'image_url'") + error_message = "image_url field must be set when type is 'image_url'" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="openai", + error=ValueError(error_message) + ) if self.type == "input_audio" and self.input_audio is None: - raise ValueError("input_audio field must be set when type is 'input_audio'") + error_message = "input_audio field must be set when type is 'input_audio'" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="openai", + error=ValueError(error_message) + ) class ChatMessage(BaseModel): diff --git a/app/exceptions/exceptions.py b/app/exceptions/exceptions.py index 5e6ebac..112527f 100644 --- a/app/exceptions/exceptions.py +++ b/app/exceptions/exceptions.py @@ -1,6 +1,83 @@ -class InvalidProviderException(Exception): +class BaseForgeException(Exception): + pass + +class InvalidProviderException(BaseForgeException): """Exception raised when a provider is invalid.""" def __init__(self, identifier: str): self.identifier = identifier - super().__init__(f"Provider {identifier} is invalid or failed to extract provider info from model_id {identifier}") \ No newline at end of file + super().__init__(f"Provider {identifier} is invalid or failed to extract provider info from model_id {identifier}") + + +class 
ProviderAuthenticationException(BaseForgeException): + """Exception raised when a provider authentication fails.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} authentication failed: {error}") + + +class BaseInvalidProviderSetupException(BaseForgeException): + """Exception raised when a provider setup is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} setup is invalid: {error}") + +class InvalidProviderConfigException(BaseInvalidProviderSetupException): + """Exception raised when a provider config is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + super().__init__(provider_name, error) + +class InvalidProviderAPIKeyException(BaseInvalidProviderSetupException): + """Exception raised when a provider API key is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + super().__init__(provider_name, error) + +class ProviderAPIException(BaseForgeException): + """Exception raised when a provider API error occurs.""" + + def __init__(self, provider_name: str, error_code: int, error_message: str): + super().__init__(f"Provider {provider_name} API error: {error_code} {error_message}") + + +class BaseInvalidRequestException(BaseForgeException): + """Exception raised when a request is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} request is invalid: {error}") + +class InvalidCompletionRequestException(BaseInvalidRequestException): + """Exception raised when a completion request is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} completion request is invalid: {error}") + +class InvalidEmbeddingsRequestException(BaseInvalidRequestException): + """Exception raised when a embeddings request is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} embeddings request is invalid: {error}") + +class BaseInvalidForgeKeyException(BaseForgeException): + """Exception raised when a Forge key is invalid.""" + + def __init__(self, error: Exception): + self.error = error + super().__init__(f"Forge key is invalid: {error}") + + +class InvalidForgeKeyException(BaseInvalidForgeKeyException): + """Exception raised when a Forge key is invalid.""" + def __init__(self, error: Exception): + super().__init__(error) \ No newline at end of file diff --git a/app/main.py b/app/main.py index 69b6047..85364a6 100644 --- a/app/main.py +++ b/app/main.py @@ -3,7 +3,7 @@ from collections.abc import Callable from dotenv import load_dotenv -from fastapi import APIRouter, FastAPI, Request +from fastapi import APIRouter, FastAPI, Request, HTTPException from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.base import BaseHTTPMiddleware @@ -20,6 +20,8 @@ from app.core.database import engine from app.core.logger import get_logger from app.models.base import Base +from app.exceptions.exceptions import ProviderAuthenticationException, InvalidProviderException, BaseInvalidProviderSetupException, \ + ProviderAPIException, BaseInvalidRequestException, 
BaseInvalidForgeKeyException load_dotenv() @@ -77,6 +79,65 @@ async def dispatch(self, request: Request, call_next: Callable): openapi_url="/openapi.json" if not is_production else None, ) +### Exception handlers block ### + +# Add exception handler for ProviderAuthenticationException +@app.exception_handler(ProviderAuthenticationException) +async def provider_authentication_exception_handler(request: Request, exc: ProviderAuthenticationException): + return HTTPException( + status_code=401, + detail=f"Authentication failed for provider {exc.provider_name}" + ) + +# Add exception handler for InvalidProviderException +@app.exception_handler(InvalidProviderException) +async def invalid_provider_exception_handler(request: Request, exc: InvalidProviderException): + return HTTPException( + status_code=400, + detail=f"{str(exc)}. Please verify your provider and model details by calling the /models endpoint or visiting https://tensorblock.co/api-docs/model-ids, and ensure you’re using a valid provider name, model name, and model ID." + ) + +# Add exception handler for BaseInvalidProviderSetupException +@app.exception_handler(BaseInvalidProviderSetupException) +async def base_invalid_provider_setup_exception_handler(request: Request, exc: BaseInvalidProviderSetupException): + return HTTPException( + status_code=400, + detail=str(exc) + ) + +# Add exception handler for ProviderAPIException +@app.exception_handler(ProviderAPIException) +async def provider_api_exception_handler(request: Request, exc: ProviderAPIException): + return HTTPException( + status_code=exc.error_code, + detail=f"Provider API error: {exc.provider_name} {exc.error_code} {exc.error_message}" + ) + +# Add exception handler for BaseInvalidRequestException +@app.exception_handler(BaseInvalidRequestException) +async def base_invalid_request_exception_handler(request: Request, exc: BaseInvalidRequestException): + return HTTPException( + status_code=400, + detail=str(exc) + ) + +# Add exception handler for BaseInvalidForgeKeyException +@app.exception_handler(BaseInvalidForgeKeyException) +async def base_invalid_forge_key_exception_handler(request: Request, exc: BaseInvalidForgeKeyException): + return HTTPException( + status_code=401, + detail=f"Invalid Forge key: {exc.error}" + ) + +# Add exception handler for NotImplementedError +@app.exception_handler(NotImplementedError) +async def not_implemented_error_handler(request: Request, exc: NotImplementedError): + return HTTPException( + status_code=404, + detail=f"Not implemented: {exc}" + ) +### Exception handlers block ends ### + # Middleware to log slow requests @app.middleware("http") diff --git a/app/services/provider_service.py b/app/services/provider_service.py index 6793283..61688e3 100644 --- a/app/services/provider_service.py +++ b/app/services/provider_service.py @@ -15,7 +15,7 @@ ) from app.core.logger import get_logger from app.core.security import decrypt_api_key, encrypt_api_key -from app.exceptions.exceptions import InvalidProviderException +from app.exceptions.exceptions import InvalidProviderException, BaseInvalidRequestException, InvalidForgeKeyException from app.models.user import User from app.services.usage_stats_service import UsageStatsService @@ -322,6 +322,7 @@ def _get_provider_info_with_prefix( ) if not matching_provider: + logger.error(f"No matching provider found for {original_model}") raise InvalidProviderException(original_model) provider_data = self.provider_keys[matching_provider] @@ -358,6 +359,7 @@ def _find_provider_for_unprefixed_model( 
provider_data.get("base_url"), ) + logger.error(f"No matching provider found for {model}") raise InvalidProviderException(model) def _get_provider_info(self, model: str) -> tuple[str, str, str | None]: @@ -365,9 +367,9 @@ def _get_provider_info(self, model: str) -> tuple[str, str, str | None]: Determine the provider based on the model name. """ if not self._keys_loaded: - raise RuntimeError( - "Provider keys must be loaded before calling _get_provider_info. Call _load_provider_keys_async() first." - ) + error_message = "Provider keys must be loaded before calling _get_provider_info. Call _load_provider_keys_async() first." + logger.error(error_message) + raise RuntimeError(error_message) provider_name, model_name_no_prefix = self._extract_provider_name_prefix(model) @@ -485,25 +487,23 @@ async def process_request( model = payload.get("model") if not model: - raise ValueError("Model is required") - - try: - provider_name, actual_model, base_url = self._get_provider_info(model) - - # Enforce scope restriction (case-insensitive). - if allowed_provider_names is not None and ( - provider_name.lower() not in {p.lower() for p in allowed_provider_names} - ): - raise ValueError( - f"API key is not permitted to use provider '{provider_name}'." - ) - except ValueError as e: - # Use parameterized logging to avoid issues if the error message contains braces - logger.error("Error getting provider info for model {}: {}", model, str(e)) - raise ValueError( - f"Invalid model ID: {model}. Please check your model configuration." + error_message = "Model is required" + logger.error(error_message) + raise BaseInvalidRequestException( + provider_name="unknown", + error=ValueError(error_message) ) + provider_name, actual_model, base_url = self._get_provider_info(model) + + # Enforce scope restriction (case-insensitive). + if allowed_provider_names is not None and ( + provider_name.lower() not in {p.lower() for p in allowed_provider_names} + ): + error_message = f"API key is not permitted to use provider '{provider_name}'." + logger.error(error_message) + raise InvalidForgeKeyException(error=ValueError(error_message)) + logger.debug( f"Processing request for provider: {provider_name}, model: {actual_model}, endpoint: {endpoint}" ) @@ -513,7 +513,9 @@ async def process_request( # Get the provider's API key if provider_name not in self.provider_keys: - raise ValueError(f"No API key found for provider {provider_name}") + error_message = f"API key is not permitted to use provider '{provider_name}'." 
+ logger.error(error_message) + raise InvalidForgeKeyException(error=ValueError(error_message)) serialized_api_key_config = self.provider_keys[provider_name]["api_key"] provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls(provider_name) @@ -534,9 +536,9 @@ async def process_request( elif "images/generations" in endpoint: # TODO: we only support openai for now if provider_name != "openai": - raise ValueError( - f"Unsupported endpoint: {endpoint} for provider {provider_name}" - ) + error_message = f"Unsupported endpoint: {endpoint} for provider {provider_name}" + logger.error(error_message) + raise NotImplementedError(error_message) result = await adapter.process_image_generation( endpoint, payload, @@ -545,9 +547,9 @@ async def process_request( elif "images/edits" in endpoint: # TODO: we only support openai for now if provider_name != "openai": - raise NotImplementedError( - f"Unsupported endpoint: {endpoint} for provider {provider_name}" - ) + error_message = f"Unsupported endpoint: {endpoint} for provider {provider_name}" + logger.error(error_message) + raise NotImplementedError(error_message) result = await adapter.process_image_edits( endpoint, payload, @@ -560,7 +562,9 @@ async def process_request( api_key, ) else: - raise ValueError(f"Unsupported endpoint: {endpoint}") + error_message = f"Unsupported endpoint: {endpoint}" + logger.error(error_message) + raise NotImplementedError(error_message) # Track usage statistics if it's not a streaming response if not inspect.isasyncgen(result): diff --git a/app/services/providers/anthropic_adapter.py b/app/services/providers/anthropic_adapter.py index 9cb2c7e..4533d31 100644 --- a/app/services/providers/anthropic_adapter.py +++ b/app/services/providers/anthropic_adapter.py @@ -7,8 +7,13 @@ import aiohttp +from app.core.logger import get_logger +from app.exceptions.exceptions import ProviderAPIException, InvalidCompletionRequestException + from .base import ProviderAdapter +logger = get_logger(name="anthropic_adapter") + ANTHROPIC_DEFAULT_MAX_TOKENS = 4096 @@ -58,21 +63,23 @@ def convert_openai_content_to_anthropic( if isinstance(content, str): return content - try: - result = [] - for msg in content: - _type = msg["type"] - if _type == "text": - result.append({"type": "text", "text": msg["text"]}) - elif _type == "image_url": - result.append( - AnthropicAdapter.convert_openai_image_content_to_anthropic(msg) - ) - else: - raise NotImplementedError(f"{_type} is not supported") - return result - except Exception as e: - raise NotImplementedError("Unsupported content type") from e + result = [] + for msg in content: + _type = msg["type"] + if _type == "text": + result.append({"type": "text", "text": msg["text"]}) + elif _type == "image_url": + result.append( + AnthropicAdapter.convert_openai_image_content_to_anthropic(msg) + ) + else: + error_message = f"{_type} is not supported" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="anthropic", + error=ValueError(error_message) + ) + return result async def list_models(self, api_key: str) -> list[str]: """List all models (verbosely) supported by the provider""" @@ -95,7 +102,12 @@ async def list_models(self, api_key: str) -> list[str]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Anthropic API error: {error_text}") + logger.error(f"List Models API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + 
error_message=error_text + ) resp = await response.json() self.CLAUDE_MODEL_MAPPING = { d["display_name"]: d["id"] for d in resp["data"] @@ -194,8 +206,11 @@ async def stream_response() -> AsyncGenerator[bytes, None]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError( - f"Anthropic API error: {response.status} - {error_text}" + logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text ) buffer = "" @@ -212,13 +227,11 @@ async def stream_response() -> AsyncGenerator[bytes, None]: elif line.startswith("data:"): data_str = line[len("data:") :].strip() - # logger.info(f"Anthropic Raw Event - Type: {event_type}, Data String: {data_str}") if not event_type or data_str is None: continue try: data = json.loads(data_str) - # logger.info(f"Anthropic Parsed Data: {data}") openai_chunk = None finish_reason = None # --- Event Processing Logic --- @@ -230,7 +243,6 @@ async def stream_response() -> AsyncGenerator[bytes, None]: captured_input_tokens = message_data["usage"].get( "input_tokens", 0 ) - # logger.info(f"Captured input_tokens: {captured_input_tokens}") captured_output_tokens = message_data["usage"].get( "output_tokens", captured_output_tokens ) @@ -267,7 +279,6 @@ async def stream_response() -> AsyncGenerator[bytes, None]: captured_output_tokens = usage_data_in_delta.get( "output_tokens", captured_output_tokens ) - # logger.info(f"Captured output_tokens from top-level delta usage: {captured_output_tokens}") if captured_input_tokens > 0: usage_info_complete = True @@ -282,7 +293,6 @@ async def stream_response() -> AsyncGenerator[bytes, None]: captured_output_tokens = usage.get( "output_tokens", captured_output_tokens ) - # logger.info(f"Captured usage from stop: in={captured_input_tokens}, out={captured_output_tokens}") if ( captured_input_tokens > 0 and captured_output_tokens > 0 @@ -315,13 +325,11 @@ async def stream_response() -> AsyncGenerator[bytes, None]: yield f"data: {json.dumps(usage_chunk)}\n\n".encode() # Reset flag to prevent duplicate yields usage_info_complete = False - # logger.info(f"Yielded usage chunk: {usage_chunk}") - except json.JSONDecodeError: - # logger.warning(f"Anthropic stream: Failed to parse JSON: {data_str}") + except json.JSONDecodeError as e: + logger.warning(f"Stream API error for {self.provider_name}: Failed to parse JSON: {e}") continue - except Exception: - # logger.error(f"Anthropic stream processing error: {e}", exc_info=True) + except Exception as e: continue # Final SSE message @@ -340,7 +348,12 @@ async def _process_regular_response( ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Anthropic API error: {error_text}") + logger.error(f"Completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) anthropic_response = await response.json() diff --git a/app/services/providers/azure_adapter.py b/app/services/providers/azure_adapter.py index 94f7506..04936b3 100644 --- a/app/services/providers/azure_adapter.py +++ b/app/services/providers/azure_adapter.py @@ -3,9 +3,10 @@ from .openai_adapter import OpenAIAdapter from app.core.logger import get_logger +from app.exceptions.exceptions import ProviderAPIException, BaseInvalidProviderSetupException # Configure logging -logger = 
get_logger(name="openai_adapter") +logger = get_logger(name="azure_adapter") class AzureAdapter(OpenAIAdapter): def __init__(self, provider_name: str, base_url: str, config: dict[str, Any]): @@ -18,15 +19,27 @@ def assert_valid_base_url(base_url: str) -> str: # base_url is required # e.g: https://THE-RESOURCE-NAME.openai.azure.com/ if not base_url: - raise ValueError("Azure base URL is required") + error_text = "Azure base URL is required" + logger.error(error_text) + raise BaseInvalidProviderSetupException( + provider_name="azure", + error=ValueError(error_text) + ) base_url = base_url.rstrip("/") return base_url @staticmethod def serialize_api_key_config(api_key: str, config: dict[str, Any] | None) -> str: """Serialize the API key for the given provider""" - assert config is not None - assert config.get("api_version") is not None + try: + assert config is not None + assert config.get("api_version") is not None + except AssertionError as e: + logger.error(str(e)) + raise BaseInvalidProviderSetupException( + provider_name="azure", + error=e + ) return json.dumps({ "api_key": api_key, @@ -36,9 +49,16 @@ def serialize_api_key_config(api_key: str, config: dict[str, Any] | None) -> str @staticmethod def deserialize_api_key_config(serialized_api_key_config: str) -> tuple[str, dict[str, Any] | None]: """Deserialize the API key for the given provider""" - deserialized_api_key_config = json.loads(serialized_api_key_config) - assert deserialized_api_key_config.get("api_key") is not None - assert deserialized_api_key_config.get("api_version") is not None + try: + deserialized_api_key_config = json.loads(serialized_api_key_config) + assert deserialized_api_key_config.get("api_key") is not None + assert deserialized_api_key_config.get("api_version") is not None + except Exception as e: + logger.error(str(e)) + raise BaseInvalidProviderSetupException( + provider_name="azure", + error=e + ) return deserialized_api_key_config["api_key"], { "api_version": deserialized_api_key_config["api_version"], diff --git a/app/services/providers/bedrock_adapter.py b/app/services/providers/bedrock_adapter.py index e38edf7..63a3f5d 100644 --- a/app/services/providers/bedrock_adapter.py +++ b/app/services/providers/bedrock_adapter.py @@ -8,6 +8,7 @@ from typing import Any from app.core.logger import get_logger +from app.exceptions.exceptions import BaseInvalidProviderSetupException, ProviderAPIException, InvalidCompletionRequestException, BaseForgeException from .base import ProviderAdapter @@ -27,13 +28,32 @@ class BedrockAdapter(ProviderAdapter): def __init__(self, provider_name: str, base_url: str, config: dict[str, str] | None = None): self._provider_name = provider_name self._base_url = base_url - assert "region_name" in config, "Bedrock region_name is required" + self.parse_config(config) + self._session = aiobotocore.session.get_session() + + @staticmethod + def validate_config(config: dict[str, str] | None): + """Validate the config for the given provider""" + + try: + assert config is not None, "Bedrock config is required" + assert "region_name" in config, "Bedrock region_name is required" + assert "aws_access_key_id" in config, "Bedrock aws_access_key_id is required" + assert "aws_secret_access_key" in config, "Bedrock aws_secret_access_key is required" + except AssertionError as e: + logger.error(str(e)) + raise BaseInvalidProviderSetupException( + provider_name="bedrock", + error=e + ) + + def parse_config(self, config: dict[str, str] | None) -> None: + """Parse the config for the given provider""" + + 
self.validate_config(config) self._region_name = config["region_name"] - assert "aws_access_key_id" in config, "Bedrock aws_access_key_id is required" self._aws_access_key_id = config["aws_access_key_id"] - assert "aws_secret_access_key" in config, "Bedrock aws_secret_access_key is required" self._aws_secret_access_key = config["aws_secret_access_key"] - self._session = aiobotocore.session.get_session() @property def client_ctx(self): @@ -50,10 +70,7 @@ def provider_name(self) -> str: @staticmethod def serialize_api_key_config(api_key: str, config: dict[str, Any] | None) -> str: """Serialize the API key for the given provider""" - assert config is not None - assert config.get("region_name") is not None - assert config.get("aws_access_key_id") is not None - assert config.get("aws_secret_access_key") is not None + BedrockAdapter.validate_config(config) return json.dumps({ "api_key": api_key, @@ -65,11 +82,18 @@ def serialize_api_key_config(api_key: str, config: dict[str, Any] | None) -> str @staticmethod def deserialize_api_key_config(serialized_api_key_config: str) -> tuple[str, dict[str, Any] | None]: """Deserialize the API key for the given provider""" - deserialized_api_key_config = json.loads(serialized_api_key_config) - assert deserialized_api_key_config.get("api_key") is not None - assert deserialized_api_key_config.get("region_name") is not None - assert deserialized_api_key_config.get("aws_access_key_id") is not None - assert deserialized_api_key_config.get("aws_secret_access_key") is not None + try: + deserialized_api_key_config = json.loads(serialized_api_key_config) + assert deserialized_api_key_config.get("api_key") is not None + assert deserialized_api_key_config.get("region_name") is not None + assert deserialized_api_key_config.get("aws_access_key_id") is not None + assert deserialized_api_key_config.get("aws_secret_access_key") is not None + except Exception as e: + logger.error(str(e)) + raise BaseInvalidProviderSetupException( + provider_name="bedrock", + error=e + ) return deserialized_api_key_config["api_key"], { "region_name": deserialized_api_key_config["region_name"], @@ -80,10 +104,7 @@ def deserialize_api_key_config(serialized_api_key_config: str) -> tuple[str, dic @staticmethod def mask_config(config: dict[str, Any] | None) -> dict[str, Any] | None: """Mask the config for the given provider""" - assert config is not None - assert config.get("region_name") is not None - assert config.get("aws_access_key_id") is not None - assert config.get("aws_secret_access_key") is not None + BedrockAdapter.validate_config(config) mask_str = "*" * 7 return { "region_name": config["region_name"][:3] + mask_str + config["region_name"][-3:], @@ -105,7 +126,13 @@ async def list_models(self, api_key: str) -> list[str]: try: response = await bedrock.list_foundation_models() except Exception as e: - raise ValueError(f"Bedrock API error: {e}") + error_text = f"List models API error for {self.provider_name}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=400, + error_message=error_text + ) models = [r["modelId"] for r in response["modelSummaries"]] @@ -152,9 +179,22 @@ async def convert_openai_image_content_to_bedrock(msg: list[dict[str, Any]] | st }, } } + except aiohttp.ClientResponseError as e: + error_text = f"Bedrock API error: failed to download image from {data_url}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name="bedrock", + error_code=e.status, + error_message=error_text + ) except 
Exception as e: - logger.warning(f"Bedrock API error: failed to download image from {data_url}: {e}") - raise Exception(f"Bedrock API error: {e}") + error_text = f"Bedrock API error: failed to download image from {data_url}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name="bedrock", + error_code=500, + error_message=error_text + ) @staticmethod async def convert_openai_content_to_bedrock(content: list[dict[str, Any]] | str) -> list[dict[str, Any]]: @@ -171,10 +211,22 @@ async def convert_openai_content_to_bedrock(content: list[dict[str, Any]] | str) elif _type == "image_url": result.append(await BedrockAdapter.convert_openai_image_content_to_bedrock(msg)) else: - raise NotImplementedError(f"{_type} is not supported") + error_text = f"Bedrock API request error: {_type} is not supported" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name="bedrock", + error=ValueError(error_text) + ) return result + except BaseForgeException as e: + raise e except Exception as e: - raise NotImplementedError("Unsupported content type") from e + error_text = f"Bedrock API request error: {e}" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name="bedrock", + error=e + ) from e @staticmethod async def convert_openai_payload_to_bedrock(payload: dict[str, Any]) -> dict[str, Any]: @@ -224,7 +276,13 @@ async def _process_regular_response(self, bedrock_payload: dict[str, Any]) -> di **bedrock_payload, ) except Exception as e: - raise ValueError(f"Bedrock API error: {e}") + error_text = f"Completion API error for {self.provider_name}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=400, + error_message=error_text + ) if message := response.get("output", {}).get("message"): completion_id = f"chatcmpl-{str(uuid.uuid4())}" @@ -239,7 +297,13 @@ async def _process_regular_response(self, bedrock_payload: dict[str, Any]) -> di if _type == "text": text_content += value else: - raise NotImplementedError(f"Bedrock API error: {_type} response is not supported") + error_text = f"Completion API error for {self.provider_name}: {_type} response is not supported" + logger.error(error_text) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=400, + error_message=error_text + ) input_tokens = response.get("usage", {}).get("inputTokens", 0) output_tokens = response.get("usage", {}).get("outputTokens", 0) @@ -365,7 +429,13 @@ async def _process_streaming_response(self, bedrock_payload: dict[str, Any]) -> if openai_chunk: yield f"data: {json.dumps(openai_chunk)}\n\n".encode() except Exception as e: - raise ValueError(f"Bedrock API error: {e}") + error_text = f"Streaming completion API error for {self.provider_name}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=400, + error_message=error_text + ) if final_usage_data: openai_chunk = { "id": request_id, @@ -382,8 +452,10 @@ async def _process_streaming_response(self, bedrock_payload: dict[str, Any]) -> # Send final [DONE] message yield b"data: [DONE]\n\n" + except BaseForgeException as e: + raise e except Exception as e: - logger.error(f"Bedrock streaming API error: {str(e)}", exc_info=True) + logger.error(f"Streaming completion API error for {self.provider_name}: {e}", exc_info=True) error_chunk = { "id": str(uuid.uuid4()), "object": "chat.completion.chunk", diff --git a/app/services/providers/cohere_adapter.py 
b/app/services/providers/cohere_adapter.py index 602ce8d..27e120e 100644 --- a/app/services/providers/cohere_adapter.py +++ b/app/services/providers/cohere_adapter.py @@ -8,6 +8,7 @@ import aiohttp from app.core.logger import get_logger +from app.exceptions.exceptions import ProviderAPIException, BaseForgeException from .base import ProviderAdapter @@ -45,7 +46,12 @@ async def list_models(self, api_key: str) -> list[str]: ) as response: if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Cohere API error: {error_text}") + logger.error(f"List models API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) resp = await response.json() models = [d["name"] for d in resp["models"]] @@ -129,8 +135,12 @@ async def _stream_cohere_response( ): if response.status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Cohere API error: {response.status} - {error_text}") - raise ValueError(f"Cohere stream API error: {error_text}") + logger.error(f"Streaming completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) # Track the message ID for consistency message_id = None @@ -282,7 +292,7 @@ async def _stream_cohere_response( # # Send final [DONE] message yield b"data: [DONE]\n\n" except Exception as e: - logger.error(f"Cohere streaming API error: {str(e)}", exc_info=True) + logger.error(f"Streaming completion API error for {self.provider_name}: {e}", exc_info=True) error_chunk = { "id": uuid.uuid4(), "object": "chat.completion.chunk", @@ -307,7 +317,12 @@ async def _process_cohere_chat_completion( async with session.post(url, headers=headers, json=payload) as response: if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Cohere API error: {error_text}") + logger.error(f"Completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) resp = await response.json() return self._convert_cohere_to_openai(resp, payload["model"]) @@ -389,10 +404,22 @@ async def process_embeddings( async with session.post(url, headers=headers, json=cohere_payload) as response: if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Cohere API error: {error_text}") + logger.error(f"Embeddings API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) response_json = await response.json() return self.convert_cohere_embeddings_response_to_openai(response_json, payload["model"]) + except BaseForgeException as e: + raise e except Exception as e: - logger.error(f"Error in Cohere embeddings: {str(e)}", exc_info=True) - raise + error_text = f"Embeddings API error for {self.provider_name}: {e}" + logger.error(error_text, exc_info=True) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=500, + error_message=error_text + ) diff --git a/app/services/providers/google_adapter.py b/app/services/providers/google_adapter.py index 74228ed..8a44820 100644 --- a/app/services/providers/google_adapter.py +++ b/app/services/providers/google_adapter.py @@ -10,6 +10,8 @@ import aiohttp from app.core.logger 
import get_logger +from app.exceptions.exceptions import BaseForgeException, BaseInvalidRequestException, ProviderAPIException, InvalidCompletionRequestException, \ + InvalidEmbeddingsRequestException from .base import ProviderAdapter @@ -55,7 +57,12 @@ async def list_models(self, api_key: str) -> list[str]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Google API error: {error_text}") + logger.error(f"List Models API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) resp = await response.json() self.GOOGLE_MODEL_MAPPING = { d["displayName"]: d["name"] for d in resp["models"] @@ -91,8 +98,12 @@ async def upload_file_to_gemini( # First, get the file metadata from the URL async with session.head(file_url) as response: if response.status != HTTPStatus.OK: - raise Exception( - f"Gemini Upload API error: Failed to fetch file metadata from URL: {response.status}" + error_text = await response.text() + logger.error(f"Gemini Upload API error: Failed to fetch file metadata from URL: {error_text}") + raise ProviderAPIException( + provider_name="google", + error_code=response.status, + error_message=error_text ) mime_type = response.headers.get("Content-Type", "application/octet-stream") @@ -116,13 +127,23 @@ async def upload_file_to_gemini( f"{base_url}?key={api_key}", headers=headers, json=metadata ) as response: if response.status != HTTPStatus.OK: - raise Exception( - f"Gemini Upload API error: Failed to initiate upload: {response.status}" + error_text = await response.text() + logger.error(f"Gemini Upload API error: Failed to initiate upload: {error_text}") + raise ProviderAPIException( + provider_name="google", + error_code=response.status, + error_message=error_text ) upload_url = response.headers.get("X-Goog-Upload-URL") if not upload_url: - raise Exception("No upload URL received from server") + error_text = "Gemini Upload API error: No upload URL received from server" + logger.error(error_text) + raise ProviderAPIException( + provider_name="google", + error_code=404, + error_message=error_text + ) # Upload the file content using streaming upload_headers = { @@ -134,16 +155,24 @@ async def upload_file_to_gemini( # Stream the file content directly from the source URL to Gemini API async with session.get(file_url) as source_response: if source_response.status != HTTPStatus.OK: - raise Exception( - f"Gemini Upload API error: Failed to fetch file content: {source_response.status}" + error_text = await source_response.text() + logger.error(f"Gemini Upload API error: Failed to fetch file content: {error_text}") + raise ProviderAPIException( + provider_name="google", + error_code=source_response.status, + error_message=error_text ) async with session.put( upload_url, headers=upload_headers, data=source_response.content ) as upload_response: if upload_response.status != HTTPStatus.OK: - raise Exception( - f"Gemini Upload API error: Failed to upload file: {upload_response.status}" + error_text = await upload_response.text() + logger.error(f"Gemini Upload API error: Failed to upload file: {error_text}") + raise ProviderAPIException( + provider_name="google", + error_code=upload_response.status, + error_message=error_text ) return await upload_response.json() @@ -179,9 +208,16 @@ async def convert_openai_image_content_to_google( "file_uri": result["file"]["uri"], } } - except Exception as e: - logger.error(f"Error uploading image 
to Google Gemini: {e}") + except ProviderAPIException as e: raise e + except Exception as e: + error_text = f"Error uploading image to Google Gemini: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name="google", + error_code=400, + error_message=error_text + ) @staticmethod async def convert_openai_content_to_google( @@ -205,10 +241,21 @@ async def convert_openai_content_to_google( ) ) else: - raise NotImplementedError(f"{_type} is not supported") + error_text = f"{_type} is not supported" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name="google", + error=ValueError(error_text) + ) return result + except BaseForgeException as e: + raise e except Exception as e: - raise NotImplementedError("Unsupported content type") from e + logger.error(f"Error converting OpenAI content to Google: {e}") + raise BaseInvalidRequestException( + provider_name="google", + error=e + ) async def process_completion( self, @@ -256,10 +303,19 @@ async def _stream_google_response( try: if not google_payload: - raise ValueError("Empty payload for Google API request") + error_text = f"Empty payload for {self.provider_name} API request" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name=self.provider_name, + error=ValueError(error_text) + ) if not api_key: - raise ValueError("Missing API key for Google request") - + error_text = f"Missing API key for {self.provider_name} API request" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name=self.provider_name, + error=ValueError(error_text) + ) headers = {"Content-Type": "application/json", "Accept": "application/json"} logger.debug( f"Google API request - URL: {url}, Payload sample: {str(google_payload)[:200]}..." 
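Every adapter touched by this patch repeats the same check-status, log, raise
sequence after an upstream HTTP call. A minimal sketch of that shared shape,
assuming the app's own get_logger and ProviderAPIException; the helper name
raise_for_provider_status is illustrative, not part of the patch:

    from http import HTTPStatus

    import aiohttp

    from app.core.logger import get_logger
    from app.exceptions.exceptions import ProviderAPIException

    logger = get_logger(name="provider_http")

    async def raise_for_provider_status(
        provider_name: str, response: aiohttp.ClientResponse
    ) -> None:
        # Read the upstream body for context, log it, and surface a
        # ProviderAPIException carrying the upstream status code so the
        # app-level handler can echo it back to the client.
        if response.status != HTTPStatus.OK:
            error_text = await response.text()
            logger.error(f"API error for {provider_name}: {error_text}")
            raise ProviderAPIException(
                provider_name=provider_name,
                error_code=response.status,
                error_message=error_text,
            )

The patch inlines this sequence at each call site rather than extracting a
helper, which keeps the per-endpoint log prefixes ("List Models", "Completion
Streaming", "Embeddings", ...) distinct.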
@@ -273,8 +329,12 @@ async def _stream_google_response( ): if response.status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Google API error: {response.status} - {error_text}") - raise ValueError(f"Google stream API error: {error_text}") + logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) # Process response in chunks buffer = "" @@ -432,7 +492,6 @@ async def _process_google_chat_completion( ) -> dict[str, Any]: """Process a regular (non-streaming) chat completion with Google Gemini""" model = payload.get("model", "") - logger.info(f"Processing regular chat completion for model: {model}") # Convert payload to Google format google_payload = await self.convert_openai_completion_payload_to_google(payload, api_key) @@ -441,7 +500,6 @@ async def _process_google_chat_completion( model_path = model if model.startswith("models/") else f"models/{model}" url = f"{self._base_url}/{model_path}:generateContent" - logger.info(f"Sending request to Google API: {url}") try: # Make the API request @@ -449,10 +507,12 @@ async def _process_google_chat_completion( # Check for API key if not api_key: - raise ValueError("Missing API key for Google request") - - logger.debug(f"Google API request - Headers: {headers}") - logger.debug(f"Google API request - URL: {url}") + error_text = f"Missing API key for {self.provider_name} API request" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name=self.provider_name, + error=ValueError(error_text) + ) async with ( aiohttp.ClientSession() as session, @@ -461,25 +521,27 @@ async def _process_google_chat_completion( ) as response, ): response_status = response.status - logger.info(f"Google API response status: {response_status}") - if response_status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Google API error: {response_status} - {error_text}") - - raise ValueError(f"Google API error: {error_text}") + logger.error(f"Completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response_status, + error_message=error_text + ) response_json = await response.json() - logger.debug( - f"Google API response: {json.dumps(response_json)[:200]}..." 
- ) # Convert to OpenAI format return self.convert_google_completion_response_to_openai(response_json, model) - + except BaseForgeException as e: + raise e except Exception as e: logger.error(f"Error in Google chat completion: {str(e)}", exc_info=True) - raise + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=e + ) @staticmethod def convert_google_completion_response_to_openai( @@ -600,7 +662,6 @@ async def process_embeddings( model_path = model if model.startswith("models/") else f"models/{model}" url = f"{self._base_url}/{model_path}:embedContent" - logger.info(f"Sending request to Google API: {url}") # Convert payload to Google format google_payload = self.convert_openai_embeddings_payload_to_google(payload) @@ -610,7 +671,12 @@ async def process_embeddings( # Check for API key if not api_key: - raise ValueError("Missing API key for Google request") + error_text = f"Missing API key for {self.provider_name} API request" + logger.error(error_text) + raise InvalidEmbeddingsRequestException( + provider_name=self.provider_name, + error=ValueError(error_text) + ) async with ( aiohttp.ClientSession() as session, @@ -619,17 +685,22 @@ async def process_embeddings( ) as response, ): response_status = response.status - logger.info(f"Google API response status: {response_status}") - if response_status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Google API error: {response_status} - {error_text}") - - raise ValueError(f"Google API error: {error_text}") + logger.error(f"Embeddings API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response_status, + error_message=error_text + ) response_json = await response.json() return self.convert_google_embeddings_response_to_openai(response_json, model) - + except BaseForgeException as e: + raise e except Exception as e: - logger.error(f"Error in Google embeddings: {str(e)}", exc_info=True) - raise + logger.error(f"Error in {self.provider_name} embeddings: {str(e)}", exc_info=True) + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=e + ) diff --git a/app/services/providers/openai_adapter.py b/app/services/providers/openai_adapter.py index fae3bf3..933778b 100644 --- a/app/services/providers/openai_adapter.py +++ b/app/services/providers/openai_adapter.py @@ -4,6 +4,7 @@ import aiohttp from app.core.logger import get_logger +from app.exceptions.exceptions import ProviderAPIException, BaseInvalidRequestException from .base import ProviderAdapter @@ -27,8 +28,7 @@ def __init__( def provider_name(self) -> str: return self._provider_name - @staticmethod - def get_model_id(payload: dict[str, Any]) -> str: + def get_model_id(self, payload: dict[str, Any]) -> str: """Get the model ID from the payload""" if "id" in payload: return payload["id"] @@ -37,7 +37,11 @@ def get_model_id(payload: dict[str, Any]) -> str: elif "model_id" in payload: return payload["model_id"] else: - raise ValueError("Model ID not found in payload") + logger.error(f"Model ID not found in payload for {self.provider_name}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError("Model ID not found in payload") + ) async def list_models( self, @@ -66,7 +70,12 @@ async def list_models( ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + logger.error(f"List Models API error for {self.provider_name}: 
{error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) resp = await response.json() # Better compatibility with Forge @@ -113,15 +122,17 @@ async def stream_response() -> AsyncGenerator[bytes, None]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError( - f"{self.provider_name} API error: {error_text}" + logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text ) # Stream the response back async for chunk in response.content: if self.provider_name == "azure": chunk = self.process_streaming_chunk(chunk) - logger.info(f"OpenAI streaming chunk: {chunk}") if chunk: yield chunk @@ -137,7 +148,12 @@ async def stream_response() -> AsyncGenerator[bytes, None]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + logger.error(f"Completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) return await response.json() @@ -161,7 +177,12 @@ async def process_image_generation( ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + logger.error(f"Image Generation API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) return await response.json() @@ -185,7 +206,12 @@ async def process_image_edits( ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + logger.error(f"API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) return await response.json() @@ -211,16 +237,17 @@ async def process_embeddings( url = f"{base_url or self._base_url}/{endpoint}" query_params = query_params or {} - try: - async with ( - aiohttp.ClientSession() as session, - session.post(url, headers=headers, json=payload, params=query_params) as response, - ): - if response.status != HTTPStatus.OK: - error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + async with ( + aiohttp.ClientSession() as session, + session.post(url, headers=headers, json=payload, params=query_params) as response, + ): + if response.status != HTTPStatus.OK: + error_text = await response.text() + logger.error(f"Embeddings API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) - return await response.json() - except Exception as e: - logger.error(f"Error in OpenAI embeddings: {str(e)}", exc_info=True) - raise + return await response.json() diff --git a/tests/unit_tests/test_provider_service_images.py b/tests/unit_tests/test_provider_service_images.py index 936e8fd..1451585 100644 --- a/tests/unit_tests/test_provider_service_images.py +++ b/tests/unit_tests/test_provider_service_images.py @@ -117,7 +117,7 @@ async def test_process_request_images_generations_routing( 
mock_anthropic_adapter.process_image_generation.side_effect = ValueError( "Unsupported endpoint: images/generations for provider anthropic" ) - with self.assertRaises(ValueError) as context: + with self.assertRaises(NotImplementedError) as context: await service.process_request( "images/generations", {"model": "claude-3-opus"} ) From 51acdac958ab0b6e2f6bfa6d2751c494f8208149 Mon Sep 17 00:00:00 2001 From: Lingtong Lu Date: Mon, 14 Jul 2025 22:51:36 -0700 Subject: [PATCH 02/10] Update API key filtering to include 'forge-' prefix (#12) --- app/api/routes/proxy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/api/routes/proxy.py b/app/api/routes/proxy.py index 3842021..fa9fbf9 100644 --- a/app/api/routes/proxy.py +++ b/app/api/routes/proxy.py @@ -51,7 +51,7 @@ async def _get_allowed_provider_names( forge_key = ( db.query(ForgeApiKey) .options(joinedload(ForgeApiKey.allowed_provider_keys)) - .filter(ForgeApiKey.key == api_key, ForgeApiKey.is_active) + .filter(ForgeApiKey.key == f"forge-{api_key}", ForgeApiKey.is_active) .first() ) if forge_key is None: From c2030afc289fed2d08cb8f29b5399a70bb9ae780 Mon Sep 17 00:00:00 2001 From: Lingtong Lu Date: Mon, 14 Jul 2025 23:37:46 -0700 Subject: [PATCH 03/10] Remove 'forge-' prefix from API key for caching purposes (#13) --- app/api/routes/proxy.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/app/api/routes/proxy.py b/app/api/routes/proxy.py index fa9fbf9..669973a 100644 --- a/app/api/routes/proxy.py +++ b/app/api/routes/proxy.py @@ -36,6 +36,8 @@ async def _get_allowed_provider_names( from app.api.dependencies import get_api_key_from_headers api_key = await get_api_key_from_headers(request) + # Remove the forge- prefix for caching from the API key + api_key = api_key[6:] allowed = getattr(request.state, "allowed_provider_names", None) if allowed is not None: From cdc7b0ffa4294e49413e7a5a853be6fbbb4f0884 Mon Sep 17 00:00:00 2001 From: Sk <110497538+Dokujaa@users.noreply.github.com> Date: Mon, 14 Jul 2025 23:52:51 -0700 Subject: [PATCH 04/10] Fix cache invalidation bug for Forge key scope updates (#11) --- app/api/routes/api_keys.py | 8 +- app/core/async_cache.py | 23 ++++ app/core/cache.py | 23 ++++ tests/unit_tests/test_cache_invalidation.py | 117 ++++++++++++++++++++ 4 files changed, 170 insertions(+), 1 deletion(-) create mode 100644 tests/unit_tests/test_cache_invalidation.py diff --git a/app/api/routes/api_keys.py b/app/api/routes/api_keys.py index 8be9d6a..b6b8c9b 100644 --- a/app/api/routes/api_keys.py +++ b/app/api/routes/api_keys.py @@ -13,7 +13,7 @@ ForgeApiKeyResponse, ForgeApiKeyUpdate, ) -from app.core.cache import invalidate_provider_service_cache, invalidate_user_cache +from app.core.cache import invalidate_provider_service_cache, invalidate_user_cache, invalidate_forge_scope_cache from app.core.database import get_db from app.core.security import generate_forge_api_key from app.models.forge_api_key import ForgeApiKey @@ -137,6 +137,10 @@ def _update_api_key_internal( db.commit() db.refresh(db_api_key) + # Invalidate forge scope cache if the scope was updated + if api_key_update.allowed_provider_key_ids is not None: + invalidate_forge_scope_cache(db_api_key.key) + response_data = db_api_key.__dict__.copy() response_data["allowed_provider_key_ids"] = [ pk.id for pk in db_api_key.allowed_provider_keys @@ -174,6 +178,7 @@ def _delete_api_key_internal( db.commit() invalidate_user_cache(key_to_invalidate) + invalidate_forge_scope_cache(key_to_invalidate) invalidate_provider_service_cache(current_user.id) return 
ForgeApiKeyResponse(**response_data) @@ -195,6 +200,7 @@ def _regenerate_api_key_internal( # Invalidate caches for the old key old_key = db_api_key.key invalidate_user_cache(old_key) + invalidate_forge_scope_cache(old_key) invalidate_provider_service_cache(current_user.id) # Generate and set new key diff --git a/app/core/async_cache.py b/app/core/async_cache.py index 43a88f0..c08e222 100644 --- a/app/core/async_cache.py +++ b/app/core/async_cache.py @@ -183,6 +183,29 @@ async def invalidate_user_cache_async(api_key: str) -> None: await async_user_cache.delete(f"user:{api_key}") +async def invalidate_forge_scope_cache_async(api_key: str) -> None: + """Invalidate forge scope cache for a specific API key asynchronously. + + Args: + api_key (str): The API key to invalidate cache for. Can include or exclude 'forge-' prefix. + """ + if not api_key: + return + + # The cache key format uses the API key WITHOUT the "forge-" prefix + # to match how it's set in get_user_by_api_key() + cache_key = api_key + if cache_key.startswith("forge-"): + cache_key = cache_key[6:] # Remove "forge-" prefix to match cache setting format + + await async_provider_service_cache.delete(f"forge_scope:{cache_key}") + + if DEBUG_CACHE: + # Mask the API key for logging + masked_key = cache_key[:8] + "..." if len(cache_key) > 8 else cache_key + logger.debug(f"Cache: Invalidated forge scope cache for API key: {masked_key} (async)") + + async def invalidate_user_cache_by_id_async(user_id: int) -> None: """Invalidate all cache entries for a specific user ID asynchronously""" if not user_id: diff --git a/app/core/cache.py b/app/core/cache.py index b02913f..01e4d15 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -170,6 +170,29 @@ def invalidate_user_cache(api_key: str) -> None: user_cache.delete(f"user:{api_key}") +def invalidate_forge_scope_cache(api_key: str) -> None: + """Invalidate forge scope cache for a specific API key. + + Args: + api_key (str): The API key to invalidate cache for. Can include or exclude 'forge-' prefix. + """ + if not api_key: + return + + # The cache key format uses the API key WITHOUT the "forge-" prefix + # to match how it's set in get_user_by_api_key() + cache_key = api_key + if cache_key.startswith("forge-"): + cache_key = cache_key[6:] # Remove "forge-" prefix to match cache setting format + + provider_service_cache.delete(f"forge_scope:{cache_key}") + + if DEBUG_CACHE: + # Mask the API key for logging + masked_key = cache_key[:8] + "..." if len(cache_key) > 8 else cache_key + logger.debug(f"Cache: Invalidated forge scope cache for API key: {masked_key}") + + # Provider service functions def get_cached_provider_service(user_id: int) -> Any: """Get a provider service from cache by user ID""" diff --git a/tests/unit_tests/test_cache_invalidation.py b/tests/unit_tests/test_cache_invalidation.py new file mode 100644 index 0000000..6b6ac59 --- /dev/null +++ b/tests/unit_tests/test_cache_invalidation.py @@ -0,0 +1,117 @@ +""" +Unit tests for cache invalidation behavior when Forge API key scope is updated. 
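+
+The forge_scope cache entry is keyed on the API key with its "forge-" prefix
+stripped (matching how get_user_by_api_key stores it), so invalidation has to
+normalise the key before deleting.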
+ +Tests the fix for issue #8: Newly added provider not reflected in allowed provider list for Forge key +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.core.async_cache import invalidate_forge_scope_cache_async +from app.core.cache import invalidate_forge_scope_cache + + +class TestForgeKeyCacheInvalidation: + """Test cache invalidation for Forge API keys""" + + def test_invalidate_forge_scope_cache_with_prefix(self): + """Test that cache invalidation works correctly with forge- prefix""" + # Mock the provider_service_cache + with patch('app.core.cache.provider_service_cache') as mock_cache: + # Test with full API key (including forge- prefix) + full_api_key = "forge-abc123def456" + + invalidate_forge_scope_cache(full_api_key) + + # Should strip the "forge-" prefix when creating cache key + expected_cache_key = "forge_scope:abc123def456" + mock_cache.delete.assert_called_once_with(expected_cache_key) + + def test_invalidate_forge_scope_cache_without_prefix(self): + """Test that cache invalidation works correctly without forge- prefix""" + # Mock the provider_service_cache + with patch('app.core.cache.provider_service_cache') as mock_cache: + # Test with stripped API key (without forge- prefix) + stripped_api_key = "abc123def456" + + invalidate_forge_scope_cache(stripped_api_key) + + # Should use the key as-is when creating cache key + expected_cache_key = "forge_scope:abc123def456" + mock_cache.delete.assert_called_once_with(expected_cache_key) + + def test_invalidate_forge_scope_cache_empty_key(self): + """Test that cache invalidation handles empty keys gracefully""" + # Mock the provider_service_cache + with patch('app.core.cache.provider_service_cache') as mock_cache: + # Test with empty API key + invalidate_forge_scope_cache("") + + # Should not call delete for empty keys + mock_cache.delete.assert_not_called() + + def test_invalidate_forge_scope_cache_none_key(self): + """Test that cache invalidation handles None keys gracefully""" + # Mock the provider_service_cache + with patch('app.core.cache.provider_service_cache') as mock_cache: + # Test with None API key + invalidate_forge_scope_cache(None) + + # Should not call delete for None keys + mock_cache.delete.assert_not_called() + + @pytest.mark.asyncio + async def test_invalidate_forge_scope_cache_async_with_prefix(self): + """Test that async cache invalidation works correctly with forge- prefix""" + # Mock the async_provider_service_cache + with patch('app.core.async_cache.async_provider_service_cache') as mock_cache: + mock_cache.delete = AsyncMock() + + # Test with full API key (including forge- prefix) + full_api_key = "forge-abc123def456" + + await invalidate_forge_scope_cache_async(full_api_key) + + # Should strip the "forge-" prefix when creating cache key + expected_cache_key = "forge_scope:abc123def456" + mock_cache.delete.assert_called_once_with(expected_cache_key) + + @pytest.mark.asyncio + async def test_invalidate_forge_scope_cache_async_without_prefix(self): + """Test that async cache invalidation works correctly without forge- prefix""" + # Mock the async_provider_service_cache + with patch('app.core.async_cache.async_provider_service_cache') as mock_cache: + mock_cache.delete = AsyncMock() + + # Test with stripped API key (without forge- prefix) + stripped_api_key = "abc123def456" + + await invalidate_forge_scope_cache_async(stripped_api_key) + + # Should use the key as-is when creating cache key + expected_cache_key = "forge_scope:abc123def456" + 
mock_cache.delete.assert_called_once_with(expected_cache_key) + + def test_cache_key_format_consistency(self): + """Test that cache invalidation uses the same key format as cache setting""" + # This test verifies the fix for issue #8 + # The bug was that cache was set with stripped key but invalidated with full key + + with patch('app.core.cache.provider_service_cache') as mock_cache: + # Simulate the DB key format (with forge- prefix) + db_api_key = "forge-d8fc7c26e350771b28fe94b7" + + # When we invalidate using the DB key + invalidate_forge_scope_cache(db_api_key) + + # It should create the same cache key format used by get_user_by_api_key + # which strips the forge- prefix: api_key = api_key_from_header[6:] + stripped_key = db_api_key[6:] # Remove "forge-" prefix + expected_cache_key = f"forge_scope:{stripped_key}" + + mock_cache.delete.assert_called_once_with(expected_cache_key) + + # Verify the exact cache key format + assert expected_cache_key == "forge_scope:d8fc7c26e350771b28fe94b7" \ No newline at end of file From 39c451d73daa8265790e471e87616a350d192c73 Mon Sep 17 00:00:00 2001 From: Lingtong Lu Date: Wed, 16 Jul 2025 21:22:21 -0700 Subject: [PATCH 05/10] Update the cache system (#14) --- app/api/dependencies.py | 10 +++--- app/api/routes/api_keys.py | 68 +++++++++++++++++++------------------- app/core/async_cache.py | 61 +++++++++++++++++++++++++++++++--- app/core/cache.py | 34 +++++++++++++------ 4 files changed, 118 insertions(+), 55 deletions(-) diff --git a/app/api/dependencies.py b/app/api/dependencies.py index 6519b3a..dcd6084 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -17,10 +17,11 @@ from app.api.schemas.user import TokenData from app.core.async_cache import ( - async_provider_service_cache, cache_user_async, get_cached_user_async, invalidate_user_cache_async, + forge_scope_cache_async, + get_forge_scope_cache_async, ) from app.core.database import get_db from app.core.logger import get_logger @@ -180,8 +181,7 @@ async def get_user_by_api_key( # Try scope cache first – this doesn't remove the need to verify the key, but it # avoids an extra query later in /models. 
- scope_cache_key = f"forge_scope:{api_key}" - cached_scope = await async_provider_service_cache.get(scope_cache_key) + cached_scope = await get_forge_scope_cache_async(api_key) api_key_record = ( db.query(ForgeApiKey) @@ -219,9 +219,7 @@ async def get_user_by_api_key( pk.provider_name for pk in api_key_record.allowed_provider_keys ] # Cache it (short TTL – scope changes are rare) - await async_provider_service_cache.set( - scope_cache_key, allowed_provider_names, ttl=300 - ) + await forge_scope_cache_async(api_key, allowed_provider_names, ttl=300) else: allowed_provider_names = cached_scope diff --git a/app/api/routes/api_keys.py b/app/api/routes/api_keys.py index b6b8c9b..723a34e 100644 --- a/app/api/routes/api_keys.py +++ b/app/api/routes/api_keys.py @@ -13,7 +13,7 @@ ForgeApiKeyResponse, ForgeApiKeyUpdate, ) -from app.core.cache import invalidate_provider_service_cache, invalidate_user_cache, invalidate_forge_scope_cache +from app.core.async_cache import invalidate_forge_scope_cache_async, invalidate_user_cache_async, invalidate_provider_service_cache_async from app.core.database import get_db from app.core.security import generate_forge_api_key from app.models.forge_api_key import ForgeApiKey @@ -25,7 +25,7 @@ # --- Internal Service Functions --- -def _get_api_keys_internal( +async def _get_api_keys_internal( db: Session, current_user: UserModel ) -> list[ForgeApiKeyMasked]: """ @@ -47,7 +47,7 @@ def _get_api_keys_internal( return masked_keys -def _create_api_key_internal( +async def _create_api_key_internal( api_key_create: ForgeApiKeyCreate, db: Session, current_user: UserModel ) -> ForgeApiKeyResponse: """ @@ -91,7 +91,7 @@ def _create_api_key_internal( return ForgeApiKeyResponse(**response_data) -def _update_api_key_internal( +async def _update_api_key_internal( key_id: int, api_key_update: ForgeApiKeyUpdate, db: Session, current_user: UserModel ) -> ForgeApiKeyResponse: """ @@ -112,7 +112,7 @@ def _update_api_key_internal( old_active_state = db_api_key.is_active db_api_key.is_active = update_data["is_active"] if old_active_state and not db_api_key.is_active: - invalidate_user_cache(db_api_key.key) + await invalidate_user_cache_async(db_api_key.key) if api_key_update.allowed_provider_key_ids is not None: db_api_key.allowed_provider_keys.clear() @@ -139,7 +139,7 @@ def _update_api_key_internal( # Invalidate forge scope cache if the scope was updated if api_key_update.allowed_provider_key_ids is not None: - invalidate_forge_scope_cache(db_api_key.key) + await invalidate_forge_scope_cache_async(db_api_key.key) response_data = db_api_key.__dict__.copy() response_data["allowed_provider_key_ids"] = [ @@ -148,7 +148,7 @@ def _update_api_key_internal( return ForgeApiKeyResponse(**response_data) -def _delete_api_key_internal( +async def _delete_api_key_internal( key_id: int, db: Session, current_user: UserModel ) -> ForgeApiKeyResponse: """ @@ -177,13 +177,13 @@ def _delete_api_key_internal( db.delete(db_api_key) db.commit() - invalidate_user_cache(key_to_invalidate) - invalidate_forge_scope_cache(key_to_invalidate) - invalidate_provider_service_cache(current_user.id) + await invalidate_user_cache_async(key_to_invalidate) + await invalidate_forge_scope_cache_async(key_to_invalidate) + await invalidate_provider_service_cache_async(current_user.id) return ForgeApiKeyResponse(**response_data) -def _regenerate_api_key_internal( +async def _regenerate_api_key_internal( key_id: int, db: Session, current_user: UserModel ) -> ForgeApiKeyResponse: """ @@ -199,9 +199,9 @@ def 
_regenerate_api_key_internal( # Invalidate caches for the old key old_key = db_api_key.key - invalidate_user_cache(old_key) - invalidate_forge_scope_cache(old_key) - invalidate_provider_service_cache(current_user.id) + await invalidate_user_cache_async(old_key) + await invalidate_forge_scope_cache_async(old_key) + await invalidate_provider_service_cache_async(current_user.id) # Generate and set new key new_key_value = generate_forge_api_key() @@ -221,91 +221,91 @@ def _regenerate_api_key_internal( @router.get("/", response_model=list[ForgeApiKeyMasked]) -def get_api_keys( +async def get_api_keys( db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _get_api_keys_internal(db, current_user) + return await _get_api_keys_internal(db, current_user) @router.post("/", response_model=ForgeApiKeyResponse) -def create_api_key( +async def create_api_key( api_key_create: ForgeApiKeyCreate, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _create_api_key_internal(api_key_create, db, current_user) + return await _create_api_key_internal(api_key_create, db, current_user) @router.put("/{key_id}", response_model=ForgeApiKeyResponse) -def update_api_key( +async def update_api_key( key_id: int, api_key_update: ForgeApiKeyUpdate, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _update_api_key_internal(key_id, api_key_update, db, current_user) + return await _update_api_key_internal(key_id, api_key_update, db, current_user) @router.delete("/{key_id}", response_model=ForgeApiKeyResponse) -def delete_api_key( +async def delete_api_key( key_id: int, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _delete_api_key_internal(key_id, db, current_user) + return await _delete_api_key_internal(key_id, db, current_user) @router.post("/{key_id}/regenerate", response_model=ForgeApiKeyResponse) -def regenerate_api_key( +async def regenerate_api_key( key_id: int, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _regenerate_api_key_internal(key_id, db, current_user) + return await _regenerate_api_key_internal(key_id, db, current_user) # Clerk versions of the routes @router.get("/clerk", response_model=list[ForgeApiKeyMasked]) -def get_api_keys_clerk( +async def get_api_keys_clerk( db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: - return _get_api_keys_internal(db, current_user) + return await _get_api_keys_internal(db, current_user) @router.post("/clerk", response_model=ForgeApiKeyResponse) -def create_api_key_clerk( +async def create_api_key_clerk( api_key_create: ForgeApiKeyCreate, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: - return _create_api_key_internal(api_key_create, db, current_user) + return await _create_api_key_internal(api_key_create, db, current_user) @router.put("/clerk/{key_id}", response_model=ForgeApiKeyResponse) -def update_api_key_clerk( +async def update_api_key_clerk( key_id: int, api_key_update: ForgeApiKeyUpdate, db: Session = Depends(get_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: - return _update_api_key_internal(key_id, api_key_update, db, current_user) + return await _update_api_key_internal(key_id, api_key_update, db, 
current_user)


@@ -296,7 +304,7 @@
 @router.delete("/clerk/{key_id}", response_model=ForgeApiKeyResponse)
-def delete_api_key_clerk(
+async def delete_api_key_clerk(
     key_id: int,
     db: Session = Depends(get_db),
     current_user: UserModel = Depends(get_current_active_user_from_clerk),
 ) -> Any:
-    return _delete_api_key_internal(key_id, db, current_user)
+    return await _delete_api_key_internal(key_id, db, current_user)
 
 
 @router.post("/clerk/{key_id}/regenerate", response_model=ForgeApiKeyResponse)
-def regenerate_api_key_clerk(
+async def regenerate_api_key_clerk(
     key_id: int,
     db: Session = Depends(get_db),
     current_user: UserModel = Depends(get_current_active_user_from_clerk),
 ) -> Any:
-    return _regenerate_api_key_internal(key_id, db, current_user)
+    return await _regenerate_api_key_internal(key_id, db, current_user)

diff --git a/app/core/async_cache.py b/app/core/async_cache.py
index c08e222..a3e34e6 100644
--- a/app/core/async_cache.py
+++ b/app/core/async_cache.py
@@ -159,9 +159,12 @@ async def wrapper(*args, **kwargs):
 
 # User-specific functions
 async def get_cached_user_async(api_key: str) -> CachedUser | None:
-    """Get a user from cache by API key asynchronously"""
+    """Get a user from cache by Forge API key asynchronously"""
     if not api_key:
         return None
+    # Remove the forge- prefix for caching from the API key
+    if api_key.startswith("forge-"):
+        api_key = api_key[6:]
     cached_data = await async_user_cache.get(f"user:{api_key}")
     if cached_data:
         return CachedUser.model_validate(cached_data)
@@ -169,17 +172,23 @@ async def get_cached_user_async(api_key: str) -> CachedUser | None:
 
 
 async def cache_user_async(api_key: str, user: User) -> None:
-    """Cache a user by API key asynchronously"""
+    """Cache a user by Forge API key asynchronously"""
     if not api_key or user is None:
         return
     cached_user = CachedUser.model_validate(user)
+    # Remove the forge- prefix for caching from the API key
+    if api_key.startswith("forge-"):
+        api_key = api_key[6:]
     await async_user_cache.set(f"user:{api_key}", cached_user.model_dump())
 
 
 async def invalidate_user_cache_async(api_key: str) -> None:
-    """Invalidate user cache for a specific API key asynchronously"""
+    """Invalidate user cache for a specific Forge API key asynchronously"""
     if not api_key:
         return
+    # Remove the forge- prefix for caching from the API key
+    if api_key.startswith("forge-"):
+        api_key = api_key[6:]
     await async_user_cache.delete(f"user:{api_key}")
 
 
@@ -246,6 +255,51 @@ async def invalidate_user_cache_by_id_async(user_id: int) -> None:
             if DEBUG_CACHE:
                 logger.debug(f"Cache: Invalidated user cache for key: {key[:8]}...")
 
+async def get_forge_scope_cache_async(api_key: str) -> list[str] | None:
+    """Get the cached provider scope for a specific Forge API key asynchronously"""
+    if not api_key:
+        return None
+    # Remove the forge- prefix for caching from the API key
+    cache_key = api_key
+    if cache_key.startswith("forge-"):
+        cache_key = cache_key[6:]
+    return await async_provider_service_cache.get(f"forge_scope:{cache_key}")
+
+
+async def forge_scope_cache_async(api_key: str, allowed_provider_names: list[str], ttl: int = 300) -> None:
+    """Cache the allowed provider scope for a specific Forge API key asynchronously"""
+    if not api_key:
+        return
+    # Remove the forge- prefix for caching from the API key
+    cache_key = api_key
+    if cache_key.startswith("forge-"):
+        cache_key = cache_key[6:]
+    await async_provider_service_cache.set(f"forge_scope:{cache_key}", allowed_provider_names, ttl=ttl)
+    if DEBUG_CACHE:
+        # Mask the API key for logging
+        masked_key = cache_key[:8] + "..."
if len(cache_key) > 8 else cache_key + logger.debug(f"Cache: set forge scope cache for Forge API key: {masked_key} (async)") + + +async def invalidate_forge_scope_cache_async(api_key: str) -> None: + """Invalidate forge scope cache for a specific API key asynchronously. + + Args: + api_key (str): The API key to invalidate cache for. Can include or exclude 'forge-' prefix. + """ + if not api_key: + return + + cache_key = api_key + if cache_key.startswith("forge-"): + cache_key = cache_key[6:] # Remove "forge-" prefix to match cache setting format + + await async_provider_service_cache.delete(f"forge_scope:{cache_key}") + + if DEBUG_CACHE: + # Mask the API key for logging + masked_key = cache_key[:8] + "..." if len(cache_key) > 8 else cache_key + logger.debug(f"Cache: Invalidated forge scope cache for Forge API key: {masked_key} (async)") # Provider service functions async def get_cached_provider_service_async(user_id: int) -> Any: @@ -355,7 +409,6 @@ async def warm_cache_async(db: Session) -> None: .all() ) for key in forge_api_keys: - # Cache user with their Forge API key await cache_user_async(key.key, user) # Cache provider services for active users diff --git a/app/core/cache.py b/app/core/cache.py index 01e4d15..26f5693 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -146,9 +146,12 @@ def wrapper(*args, **kwargs): # User-specific functions def get_cached_user(api_key: str) -> CachedUser | None: - """Get a user from cache by API key""" + """Get a user from cache by Forge API key""" if not api_key: return None + # Remove the forge- prefix for caching from the API key + if api_key.startswith("forge-"): + api_key = api_key[6:] cached_data = user_cache.get(f"user:{api_key}") if cached_data: return CachedUser.model_validate(cached_data) @@ -156,17 +159,23 @@ def get_cached_user(api_key: str) -> CachedUser | None: def cache_user(api_key: str, user: User) -> None: - """Cache a user by API key""" + """Cache a user by Forge API key""" if not api_key or user is None: return cached_user = CachedUser.model_validate(user) + # Remove the forge- prefix for caching from the API key + if api_key.startswith("forge-"): + api_key = api_key[6:] user_cache.set(f"user:{api_key}", cached_user.model_dump()) def invalidate_user_cache(api_key: str) -> None: - """Invalidate user cache for a specific API key""" + """Invalidate user cache for a specific Forge API key""" if not api_key: return + # Remove the forge- prefix for caching from the API key + if api_key.startswith("forge-"): + api_key = api_key[6:] user_cache.delete(f"user:{api_key}") @@ -190,7 +199,7 @@ def invalidate_forge_scope_cache(api_key: str) -> None: if DEBUG_CACHE: # Mask the API key for logging masked_key = cache_key[:8] + "..." 
if len(cache_key) > 8 else cache_key - logger.debug(f"Cache: Invalidated forge scope cache for API key: {masked_key}") + logger.debug(f"Cache: Invalidated forge scope cache for Forge API key: {masked_key}") # Provider service functions @@ -323,7 +332,7 @@ def invalidate_all_caches() -> None: async def warm_cache(db: Session) -> None: """Pre-cache frequently accessed data""" from app.core.security import decrypt_api_key - from app.models.provider_key import ProviderKey + from app.models.forge_api_key import ForgeApiKey from app.models.user import User from app.services.provider_service import ProviderService @@ -333,12 +342,15 @@ async def warm_cache(db: Session) -> None: # Cache active users active_users = db.query(User).filter(User.is_active).all() for user in active_users: - # Get user's API keys - api_keys = db.query(ProviderKey).filter(ProviderKey.user_id == user.id).all() - for key in api_keys: - # Decrypt the API key before caching - decrypted_key = decrypt_api_key(key.encrypted_api_key) - cache_user(decrypted_key, user) + # Get user's Forge API keys + forge_api_keys = ( + db.query(ForgeApiKey) + .filter(ForgeApiKey.user_id == user.id, ForgeApiKey.is_active) + .all() + ) + for key in forge_api_keys: + # Cache user with their Forge API key + cache_user(key.key, user) # Cache provider services for active users for user in active_users: From 48a553ec29dcb2d1579ed107abee57ff0c466965 Mon Sep 17 00:00:00 2001 From: Yiming Cheng <84763321+EaminC@users.noreply.github.com> Date: Sat, 19 Jul 2025 11:40:28 -0500 Subject: [PATCH 06/10] [docs]:Add Usage Statistics API documentation (#17) --- docs/user_guide.md | 84 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/docs/user_guide.md b/docs/user_guide.md index 9c79ec4..418dfef 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -21,6 +21,7 @@ This guide provides detailed instructions for using Forge, the AI provider middl - [API Integration](#api-integration) - [OpenAI API Compatibility](#openai-api-compatibility) - [Chat Completions](#chat-completions) + - [Usage Statistics](#usage-statistics) - [Model Support](#model-support) - [Advanced Features](#advanced-features) - [Custom Model Mapping](#custom-model-mapping) @@ -211,6 +212,89 @@ curl -X POST http://localhost:8000/chat/completions \ For detailed API documentation, visit the Swagger UI at `http://localhost:8000/docs` when the server is running. 
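+
+**Using the OpenAI Python SDK:**
+
+Because Forge exposes an OpenAI-compatible API, an OpenAI client can usually be
+pointed at it directly; a minimal sketch (assuming `openai>=1.0` is installed
+and the model below is available through one of your provider keys):
+
+```python
+from openai import OpenAI
+
+# Route requests through Forge by authenticating with your Forge API key
+client = OpenAI(api_key="your_forge_api_key", base_url="http://localhost:8000")
+
+response = client.chat.completions.create(
+    model="gpt-4.1",
+    messages=[{"role": "user", "content": "Hello!"}],
+)
+print(response.choices[0].message.content)
+```
+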
+### Usage Statistics
+
+To monitor your API usage and track consumption across different providers and models:
+
+```bash
+curl -X GET "http://localhost:8000/v1/stats/" \
+  -H "Authorization: Bearer your_forge_api_key"
+```
+
+#### Query Parameters
+
+You can filter usage statistics using the following parameters:
+
+- `provider`: Filter by provider name (e.g., "OpenAI", "Azure", "Anthropic")
+- `model`: Filter by model name (e.g., "gpt-4.1", "claude-3")
+- `start_date`: Start date for filtering (YYYY-MM-DD format)
+- `end_date`: End date for filtering (YYYY-MM-DD format)
+
+#### Example Queries
+
+**Get usage for a specific provider:**
+
+```bash
+curl -X GET "http://localhost:8000/v1/stats/?provider=OpenAI" \
+  -H "Authorization: Bearer your_forge_api_key"
+```
+
+**Get usage for a specific model:**
+
+```bash
+curl -X GET "http://localhost:8000/v1/stats/?model=gpt-4.1" \
+  -H "Authorization: Bearer your_forge_api_key"
+```
+
+**Get usage for a date range:**
+
+```bash
+curl -X GET "http://localhost:8000/v1/stats/?start_date=2024-01-01&end_date=2024-01-31" \
+  -H "Authorization: Bearer your_forge_api_key"
+```
+
+**Combine multiple filters:**
+
+```bash
+curl -X GET "http://localhost:8000/v1/stats/?provider=OpenAI&model=gpt-4.1&start_date=2024-01-01" \
+  -H "Authorization: Bearer your_forge_api_key"
+```
+
+#### Response Format
+
+The API returns a JSON array with usage statistics:
+
+```json
+[
+  {
+    "provider_name": "OpenAI",
+    "model": "gpt-4",
+    "input_tokens": 10000,
+    "output_tokens": 5000,
+    "total_tokens": 15000,
+    "requests_count": 15,
+    "cost": 0.0
+  }
+]
+```
+
+**Response Fields:**
+
+- `provider_name`: The AI provider name
+- `model`: The specific model used
+- `input_tokens`: Number of input tokens consumed
+- `output_tokens`: Number of output tokens generated
+- `total_tokens`: Total tokens (input + output)
+- `requests_count`: Number of API requests made
+- `cost`: Estimated cost (if available)
+
+This is useful for:
+
+- Monitoring API usage and costs
+- Tracking consumption across different providers
+- Analyzing usage patterns by model or time period
+- Budget planning and resource allocation
+
 ### Model Support
 
 Forge supports models from various providers:

From 13920d407a596a1411d21377dded7cd55cf31c36 Mon Sep 17 00:00:00 2001
From: Lingtong Lu
Date: Sat, 19 Jul 2025 22:35:44 -0700
Subject: [PATCH 07/10] Refactor database interactions to use AsyncSession and
 switch to async cache (#16)

---
 ...34d338f7_update_model_mapping_type_for_.py |  42 ++
 app/api/dependencies.py                       |  69 +--
 app/api/routes/api_auth.py                    |   8 +-
 app/api/routes/api_keys.py                    |  98 ++--
 app/api/routes/auth.py                        |  22 +-
 app/api/routes/provider_keys.py               | 491 ++++++++----------
 app/api/routes/proxy.py                       |  34 +-
 app/api/routes/stats.py                       |  16 +-
 app/api/routes/users.py                       |  87 ++--
 app/api/routes/webhooks.py                    | 221 ++++----
 app/api/schemas/provider_key.py               |   1 +
 app/core/async_cache.py                       |  16 +-
 app/core/cache.py                             |   6 +-
 app/core/database.py                          |  70 ++-
 app/models/provider_key.py                    |   4 +-
 app/services/provider_service.py              | 110 ++--
 app/services/usage_stats_service.py           |  20 +-
 forge-cli.py                                  |  61 ++-
 pyproject.toml                                |   2 +-
 tests/cache/test_async_cache.py               |  98 ++--
 tests/cache/test_sync_cache.py                | 374 -------------
 tests/mock_testing/add_mock_provider.py       | 159 +++---
 tests/mock_testing/setup_test_user.py         | 205 ++++----
 tests/unit_tests/test_provider_service.py     |  49 +-
 .../test_provider_service_images.py           |  38 +-
 tools/diagnostics/fix_model_mapping.py        |  14 +-
 uv.lock                                       |   9 +-
 27 files changed, 1033 insertions(+), 1291 deletions(-)
 create 
mode 100644 alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py delete mode 100755 tests/cache/test_sync_cache.py diff --git a/alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py b/alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py new file mode 100644 index 0000000..4428b60 --- /dev/null +++ b/alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py @@ -0,0 +1,42 @@ +"""update model_mapping type for ProviderKey table + +Revision ID: 9daf34d338f7 +Revises: 08cc005a4bc8 +Create Date: 2025-07-18 21:32:48.791253 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision = "9daf34d338f7" +down_revision = "08cc005a4bc8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Change model_mapping column from String to JSON + op.alter_column( + "provider_keys", + "model_mapping", + existing_type=sa.String(), + type_=postgresql.JSON(astext_type=sa.Text()), + existing_nullable=True, + postgresql_using="model_mapping::json", + ) + + +def downgrade() -> None: + # Change model_mapping column from JSON back to String + op.alter_column( + "provider_keys", + "model_mapping", + existing_type=postgresql.JSON(astext_type=sa.Text()), + type_=sa.String(), + existing_nullable=True, + postgresql_using="model_mapping::text", + ) diff --git a/app/api/dependencies.py b/app/api/dependencies.py index dcd6084..f8333ae 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -12,8 +12,10 @@ from fastapi import Depends, HTTPException, Request, status from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from jose import JWTError, jwt +from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session, joinedload +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session, joinedload, selectinload from app.api.schemas.user import TokenData from app.core.async_cache import ( @@ -23,7 +25,7 @@ forge_scope_cache_async, get_forge_scope_cache_async, ) -from app.core.database import get_db +from app.core.database import get_db, get_async_db from app.core.logger import get_logger from app.core.security import ( ALGORITHM, @@ -91,7 +93,7 @@ async def fetch_and_cache_jwks() -> list | None: async def get_current_user( - db: Session = Depends(get_db), token: str = Depends(oauth2_scheme) + db: AsyncSession = Depends(get_async_db), token: str = Depends(oauth2_scheme) ): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -106,7 +108,9 @@ async def get_current_user( token_data = TokenData(username=username) except JWTError as err: raise credentials_exception from err - user = db.query(User).filter(User.username == token_data.username).first() + + result = await db.execute(select(User).filter(User.username == token_data.username)) + user = result.scalar_one_or_none() if user is None: raise credentials_exception return user @@ -143,7 +147,7 @@ async def get_api_key_from_headers(request: Request) -> str: async def get_user_by_api_key( request: Request = None, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> User: """Get user by API key from headers, with caching""" api_key_from_header = await get_api_key_from_headers(request) @@ -183,12 +187,12 @@ async def get_user_by_api_key( # avoids an extra query later in /models. 
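    # The cached value is the list of provider names this Forge key may use; the
    # cache helpers strip the "forge-" prefix so the set, get, and invalidate
    # paths all agree on the same "forge_scope:<key>" cache key.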
cached_scope = await get_forge_scope_cache_async(api_key) - api_key_record = ( - db.query(ForgeApiKey) - .options(joinedload(ForgeApiKey.allowed_provider_keys)) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.key == api_key_from_header, ForgeApiKey.is_active) - .first() ) + api_key_record = result.scalar_one_or_none() if not api_key_record: raise HTTPException( @@ -197,12 +201,12 @@ async def get_user_by_api_key( ) # Get the user associated with this API key and EAGER LOAD all provider keys - user = ( - db.query(User) - .options(joinedload(User.provider_keys)) + result = await db.execute( + select(User) + .options(selectinload(User.provider_keys)) .filter(User.id == api_key_record.user_id) - .first() ) + user = result.scalar_one_or_none() if not user: raise HTTPException( @@ -230,7 +234,7 @@ async def get_user_by_api_key( # Update last used timestamp for the API key api_key_record.last_used_at = datetime.utcnow() - db.commit() + await db.commit() # Cache the user data for future requests await cache_user_async(api_key, user) @@ -338,7 +342,7 @@ async def validate_clerk_jwt(token: str = Depends(clerk_token_header)): async def get_current_user_from_clerk( - db: Session = Depends(get_db), token_payload: dict = Depends(validate_clerk_jwt) + db: AsyncSession = Depends(get_async_db), token_payload: dict = Depends(validate_clerk_jwt) ): """Get the current user from Clerk token, creating if needed""" from urllib.parse import quote @@ -352,7 +356,8 @@ async def get_current_user_from_clerk( ) # Find user by clerk_user_id - user = db.query(User).filter(User.clerk_user_id == clerk_user_id).first() + result = await db.execute(select(User).filter(User.clerk_user_id == clerk_user_id)) + user = result.scalar_one_or_none() # User doesn't exist yet, create one if not user: @@ -398,7 +403,8 @@ async def get_current_user_from_clerk( username = email # Check if username exists and make unique if needed - existing_user = db.query(User).filter(User.username == username).first() + result = await db.execute(select(User).filter(User.username == username)) + existing_user = result.scalar_one_or_none() if existing_user: import random @@ -412,20 +418,22 @@ async def get_current_user_from_clerk( username = clerk_user_id # Check if user exists with this email - existing_user = db.query(User).filter(User.email == email).first() + result = await db.execute(select(User).filter(User.email == email)) + existing_user = result.scalar_one_or_none() if existing_user: # Link existing user to Clerk ID try: existing_user.clerk_user_id = clerk_user_id - db.commit() + await db.commit() return existing_user except IntegrityError: # Another request might have already linked this user or created a new one - db.rollback() + await db.rollback() # Retry the query to get the user - user = ( - db.query(User).filter(User.clerk_user_id == clerk_user_id).first() + result = await db.execute( + select(User).filter(User.clerk_user_id == clerk_user_id) ) + user = result.scalar_one_or_none() if user: return user # If still no user, continue with creation attempt @@ -439,17 +447,15 @@ async def get_current_user_from_clerk( username=username, clerk_user_id=clerk_user_id, is_active=True, - hashed_password=get_password_hash( - "CLERK_AUTH_USER" - ), # Add placeholder password for Clerk users + hashed_password="", # Clerk handles authentication ) db.add(user) - db.commit() - db.refresh(user) + await db.commit() + await db.refresh(user) # Create default TensorBlock provider 
for the new user try: - create_default_tensorblock_provider_for_user(user.id, db) + await create_default_tensorblock_provider_for_user(user.id, db) except Exception as e: # Log error but don't fail user creation logger.warning( @@ -459,12 +465,13 @@ async def get_current_user_from_clerk( return user except IntegrityError as e: # Handle race condition: another request might have created the user - db.rollback() + await db.rollback() if "users_clerk_user_id_key" in str(e) or "clerk_user_id" in str(e): # Retry the query to get the user that was created by another request - user = ( - db.query(User).filter(User.clerk_user_id == clerk_user_id).first() + result = await db.execute( + select(User).filter(User.clerk_user_id == clerk_user_id) ) + user = result.scalar_one_or_none() if user: return user else: diff --git a/app/api/routes/api_auth.py b/app/api/routes/api_auth.py index 6b7c232..4a68748 100644 --- a/app/api/routes/api_auth.py +++ b/app/api/routes/api_auth.py @@ -6,21 +6,21 @@ import requests from fastapi import APIRouter, Depends, HTTPException, Request, status from jose import jwt -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies import ( get_current_active_user, get_current_active_user_from_clerk, ) from app.api.schemas.user import User -from app.core.database import get_db +from app.core.database import get_async_db from app.models.user import User as UserModel router = APIRouter() async def get_user_from_any_auth( - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), jwt_user: UserModel | None = Depends(get_current_active_user), clerk_user: UserModel | None = Depends(get_current_active_user_from_clerk), ) -> UserModel: @@ -45,7 +45,7 @@ async def get_user_from_any_auth( @router.get("/me", response_model=User) -def get_unified_current_user( +async def get_unified_current_user( current_user: UserModel = Depends(get_user_from_any_auth), ) -> Any: """ diff --git a/app/api/routes/api_keys.py b/app/api/routes/api_keys.py index 723a34e..27f1642 100644 --- a/app/api/routes/api_keys.py +++ b/app/api/routes/api_keys.py @@ -1,7 +1,9 @@ from typing import Any from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.api.dependencies import ( get_current_active_user, @@ -14,7 +16,7 @@ ForgeApiKeyUpdate, ) from app.core.async_cache import invalidate_forge_scope_cache_async, invalidate_user_cache_async, invalidate_provider_service_cache_async -from app.core.database import get_db +from app.core.database import get_async_db from app.core.security import generate_forge_api_key from app.models.forge_api_key import ForgeApiKey from app.models.provider_key import ProviderKey as ProviderKeyModel @@ -26,15 +28,17 @@ async def _get_api_keys_internal( - db: Session, current_user: UserModel + db: AsyncSession, current_user: UserModel ) -> list[ForgeApiKeyMasked]: """ Internal logic to get all API keys for the current user. 
""" - api_keys_query = db.query(ForgeApiKey).filter( - ForgeApiKey.user_id == current_user.id + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) + .filter(ForgeApiKey.user_id == current_user.id) ) - api_keys = api_keys_query.all() + api_keys = result.scalars().all() masked_keys = [] for api_key_db in api_keys: @@ -48,7 +52,7 @@ async def _get_api_keys_internal( async def _create_api_key_internal( - api_key_create: ForgeApiKeyCreate, db: Session, current_user: UserModel + api_key_create: ForgeApiKeyCreate, db: AsyncSession, current_user: UserModel ) -> ForgeApiKeyResponse: """ Internal logic to create a new API key for the current user. @@ -63,14 +67,13 @@ async def _create_api_key_internal( if api_key_create.allowed_provider_key_ids is not None: allowed_providers = [] if api_key_create.allowed_provider_key_ids: - allowed_providers = ( - db.query(ProviderKeyModel) - .filter( + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.id.in_(api_key_create.allowed_provider_key_ids), ProviderKeyModel.user_id == current_user.id, ) - .all() ) + allowed_providers = result.scalars().all() if len(allowed_providers) != len( set(api_key_create.allowed_provider_key_ids) ): @@ -81,8 +84,8 @@ async def _create_api_key_internal( db_api_key.allowed_provider_keys = allowed_providers db.add(db_api_key) - db.commit() - db.refresh(db_api_key) + await db.commit() + await db.refresh(db_api_key) response_data = db_api_key.__dict__.copy() response_data["allowed_provider_key_ids"] = [ @@ -92,16 +95,18 @@ async def _create_api_key_internal( async def _update_api_key_internal( - key_id: int, api_key_update: ForgeApiKeyUpdate, db: Session, current_user: UserModel + key_id: int, api_key_update: ForgeApiKeyUpdate, db: AsyncSession, current_user: UserModel ) -> ForgeApiKeyResponse: """ Internal logic to update an API key for the current user. """ - db_api_key = ( - db.query(ForgeApiKey) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id) - .first() ) + db_api_key = result.scalar_one_or_none() + if not db_api_key: raise HTTPException(status_code=404, detail="API key not found") @@ -117,14 +122,13 @@ async def _update_api_key_internal( if api_key_update.allowed_provider_key_ids is not None: db_api_key.allowed_provider_keys.clear() if api_key_update.allowed_provider_key_ids: - allowed_providers = ( - db.query(ProviderKeyModel) - .filter( + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.id.in_(api_key_update.allowed_provider_key_ids), ProviderKeyModel.user_id == current_user.id, ) - .all() ) + allowed_providers = result.scalars().all() if len(allowed_providers) != len( set(api_key_update.allowed_provider_key_ids) ): @@ -134,8 +138,8 @@ async def _update_api_key_internal( ) db_api_key.allowed_provider_keys.extend(allowed_providers) - db.commit() - db.refresh(db_api_key) + await db.commit() + await db.refresh(db_api_key) # Invalidate forge scope cache if the scope was updated if api_key_update.allowed_provider_key_ids is not None: @@ -149,16 +153,18 @@ async def _update_api_key_internal( async def _delete_api_key_internal( - key_id: int, db: Session, current_user: UserModel + key_id: int, db: AsyncSession, current_user: UserModel ) -> ForgeApiKeyResponse: """ Internal logic to delete an API key for the current user. 
""" - db_api_key = ( - db.query(ForgeApiKey) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id) - .first() ) + db_api_key = result.scalar_one_or_none() + if not db_api_key: raise HTTPException(status_code=404, detail="API key not found") @@ -174,8 +180,8 @@ async def _delete_api_key_internal( "allowed_provider_key_ids": [pk.id for pk in db_api_key.allowed_provider_keys], } - db.delete(db_api_key) - db.commit() + await db.delete(db_api_key) + await db.commit() await invalidate_user_cache_async(key_to_invalidate) await invalidate_forge_scope_cache_async(key_to_invalidate) @@ -184,16 +190,18 @@ async def _delete_api_key_internal( async def _regenerate_api_key_internal( - key_id: int, db: Session, current_user: UserModel + key_id: int, db: AsyncSession, current_user: UserModel ) -> ForgeApiKeyResponse: """ Internal logic to regenerate an API key for the current user while preserving other settings. """ - db_api_key = ( - db.query(ForgeApiKey) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id) - .first() ) + db_api_key = result.scalar_one_or_none() + if not db_api_key: raise HTTPException(status_code=404, detail="API key not found") @@ -207,8 +215,8 @@ async def _regenerate_api_key_internal( new_key_value = generate_forge_api_key() db_api_key.key = new_key_value - db.commit() - db.refresh(db_api_key) + await db.commit() + await db.refresh(db_api_key) response_data = db_api_key.__dict__.copy() response_data["allowed_provider_key_ids"] = [ @@ -222,7 +230,7 @@ async def _regenerate_api_key_internal( @router.get("/", response_model=list[ForgeApiKeyMasked]) async def get_api_keys( - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _get_api_keys_internal(db, current_user) @@ -231,7 +239,7 @@ async def get_api_keys( @router.post("/", response_model=ForgeApiKeyResponse) async def create_api_key( api_key_create: ForgeApiKeyCreate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _create_api_key_internal(api_key_create, db, current_user) @@ -241,7 +249,7 @@ async def create_api_key( async def update_api_key( key_id: int, api_key_update: ForgeApiKeyUpdate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _update_api_key_internal(key_id, api_key_update, db, current_user) @@ -250,7 +258,7 @@ async def update_api_key( @router.delete("/{key_id}", response_model=ForgeApiKeyResponse) async def delete_api_key( key_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _delete_api_key_internal(key_id, db, current_user) @@ -259,7 +267,7 @@ async def delete_api_key( @router.post("/{key_id}/regenerate", response_model=ForgeApiKeyResponse) async def regenerate_api_key( key_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _regenerate_api_key_internal(key_id, db, current_user) @@ -268,7 +276,7 @@ async 
def regenerate_api_key( # Clerk versions of the routes @router.get("/clerk", response_model=list[ForgeApiKeyMasked]) async def get_api_keys_clerk( - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _get_api_keys_internal(db, current_user) @@ -277,7 +285,7 @@ async def get_api_keys_clerk( @router.post("/clerk", response_model=ForgeApiKeyResponse) async def create_api_key_clerk( api_key_create: ForgeApiKeyCreate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _create_api_key_internal(api_key_create, db, current_user) @@ -287,7 +295,7 @@ async def create_api_key_clerk( async def update_api_key_clerk( key_id: int, api_key_update: ForgeApiKeyUpdate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _update_api_key_internal(key_id, api_key_update, db, current_user) @@ -296,7 +304,7 @@ async def update_api_key_clerk( @router.delete("/clerk/{key_id}", response_model=ForgeApiKeyResponse) async def delete_api_key_clerk( key_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _delete_api_key_internal(key_id, db, current_user) @@ -305,7 +313,7 @@ async def delete_api_key_clerk( @router.post("/clerk/{key_id}/regenerate", response_model=ForgeApiKeyResponse) async def regenerate_api_key_clerk( key_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _regenerate_api_key_internal(key_id, db, current_user) diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py index df80cd5..c81ad21 100644 --- a/app/api/routes/auth.py +++ b/app/api/routes/auth.py @@ -3,11 +3,12 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from app.api.routes.users import create_user as create_user_endpoint_logic from app.api.schemas.user import Token, User, UserCreate -from app.core.database import get_db +from app.core.database import get_async_db from app.core.logger import get_logger from app.core.security import ( ACCESS_TOKEN_EXPIRE_MINUTES, @@ -22,9 +23,9 @@ @router.post("/register", response_model=User) -def register( +async def register( user_in: UserCreate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Register new user. This will create the user but will not automatically create a Forge API key. @@ -33,7 +34,7 @@ def register( # Call the user creation logic from users.py # This handles checks for existing email/username and password hashing. 
    try:
-        db_user = create_user_endpoint_logic(user_in=user_in, db=db)
+        db_user = await create_user_endpoint_logic(user_in=user_in, db=db)
     except HTTPException as e:  # Propagate HTTPExceptions (like 400 for existing user)
         raise e
     except Exception as e:  # Catch any other unexpected errors during user creation
@@ -73,13 +74,18 @@ def register(
 
 
 @router.post("/token", response_model=Token)
-def login_for_access_token(
-    db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends()
+async def login_for_access_token(
+    db: AsyncSession = Depends(get_async_db),
+    form_data: OAuth2PasswordRequestForm = Depends()
 ) -> Any:
     """
     Get an access token for future API requests.
     """
-    user = db.query(UserModel).filter(UserModel.username == form_data.username).first()
+    result = await db.execute(
+        select(UserModel).filter(UserModel.username == form_data.username)
+    )
+    user = result.scalar_one_or_none()
+
     if not user or not verify_password(form_data.password, user.hashed_password):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
diff --git a/app/api/routes/provider_keys.py b/app/api/routes/provider_keys.py
index 1aff50d..523fd1a 100644
--- a/app/api/routes/provider_keys.py
+++ b/app/api/routes/provider_keys.py
@@ -2,7 +2,8 @@
 from typing import Any
 
 from fastapi import APIRouter, Depends, HTTPException
-from sqlalchemy.orm import Session
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
 from starlette import status
 
 from app.api.dependencies import (
@@ -15,13 +16,14 @@
     ProviderKeyUpdate,
     ProviderKeyUpsertItem,
 )
-from app.core.cache import invalidate_provider_service_cache
-from app.core.database import get_db
+from app.core.async_cache import invalidate_provider_service_cache_async
+from app.core.database import get_async_db
 from app.core.logger import get_logger
 from app.core.security import decrypt_api_key, encrypt_api_key
 from app.models.provider_key import ProviderKey as ProviderKeyModel
-from app.models.user import User
+from app.models.user import User as UserModel
 from app.services.providers.adapter_factory import ProviderAdapterFactory
+from app.services.providers.base import ProviderAdapter
 
 logger = get_logger(name="provider_keys")
 
@@ -29,247 +31,280 @@
 
 # --- Internal Service Functions ---
 
+def _validate_provider_cls_init(provider_name: str, base_url: str, config: dict[str, Any]) -> type[ProviderAdapter]:
+    provider_cls = ProviderAdapterFactory.get_adapter_cls(provider_name)
+    try:
+        provider_cls(provider_name, base_url, config=config)
+    except Exception as e:
+        logger.error({
+            "message": f"Error initializing provider {provider_name}",
+            "extra": {
+                "error": str(e),
+            }
+        })
+        raise HTTPException(
+            status_code=400,
+            detail=f"Error initializing provider {provider_name}",
+        ) from e
+    return provider_cls
+
-def _get_provider_keys_internal(
-    db: Session, current_user: User
-) -> list[ProviderKeyModel]:
-    """Internal. Retrieve all provider keys for the current user."""
-    return (
-        db.query(ProviderKeyModel)
-        .filter(ProviderKeyModel.user_id == current_user.id)
-        .all()
+async def _get_provider_keys_internal(
+    db: AsyncSession, current_user: UserModel
+) -> list[ProviderKey]:
+    """
+    Internal logic to get all provider keys for the current user. 
+ """ + result = await db.execute( + select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id) ) + provider_keys = result.scalars().all() + return [ProviderKey.model_validate(pk) for pk in provider_keys] -def _create_provider_key_internal( - provider_key_in: ProviderKeyCreate, db: Session, current_user: User +async def _process_provider_key_create_data( + db: AsyncSession, + provider_key_create: ProviderKeyCreate, + user_id: int, ) -> ProviderKeyModel: - """Internal. Create a new provider key.""" - existing_key = ( - db.query(ProviderKeyModel) - .filter( + provider_name = provider_key_create.provider_name + provider_cls = _validate_provider_cls_init(provider_name, provider_key_create.base_url, provider_key_create.config) + serialized_api_key_config = provider_cls.serialize_api_key_config(provider_key_create.api_key, provider_key_create.config) + + encrypted_key = encrypt_api_key(serialized_api_key_config) + db_provider_key = ProviderKeyModel( + user_id=user_id, + provider_name=provider_name, + encrypted_api_key=encrypted_key, + base_url=provider_key_create.base_url, + model_mapping=provider_key_create.model_mapping, + ) + db.add(db_provider_key) + return db_provider_key + + +async def _create_provider_key_internal( + provider_key_create: ProviderKeyCreate, db: AsyncSession, current_user: UserModel +) -> ProviderKey: + """ + Internal logic to create a new provider key for the current user. + """ + # Check if provider already exists for user + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.user_id == current_user.id, - ProviderKeyModel.provider_name == provider_key_in.provider_name, + ProviderKeyModel.provider_name == provider_key_create.provider_name, ) - .first() ) + existing_key = result.scalar_one_or_none() + if existing_key: raise HTTPException( status_code=400, - detail=f"A key for provider {provider_key_in.provider_name} already exists", + detail=f"Provider key for {provider_key_create.provider_name} already exists", ) + + db_provider_key = await _process_provider_key_create_data(db, provider_key_create, current_user.id) + await db.commit() + await db.refresh(db_provider_key) - model_mapping_json = ( - json.dumps(provider_key_in.model_mapping) - if provider_key_in.model_mapping - else None - ) + # Invalidate caches after creating a new provider key + await invalidate_provider_service_cache_async(current_user.id) - provider_name = provider_key_in.provider_name - provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls(provider_name) + return ProviderKey.model_validate(db_provider_key) - # try to initialize the provider adapter - try: - provider_adapter_cls( - provider_name, provider_key_in.base_url, config=provider_key_in.config - ) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Error initializing provider {provider_name}: {e}", - ) - serialized_api_key_config = provider_adapter_cls.serialize_api_key_config( - provider_key_in.api_key, provider_key_in.config - ) +async def _process_provider_key_update_data( + db_provider_key: ProviderKeyModel, + provider_key_update: ProviderKeyUpdate, +) -> ProviderKeyModel: + update_data = provider_key_update.model_dump(exclude_unset=True) + provider_cls = ProviderAdapterFactory.get_adapter_cls(db_provider_key.provider_name) + old_api_key, old_config = provider_cls.deserialize_api_key_config(decrypt_api_key(db_provider_key.encrypted_api_key)) - provider_key = ProviderKeyModel( - provider_name=provider_name, - 
encrypted_api_key=encrypt_api_key(serialized_api_key_config), - user_id=current_user.id, - base_url=provider_key_in.base_url, - model_mapping=model_mapping_json, - ) - db.add(provider_key) - db.commit() - db.refresh(provider_key) - invalidate_provider_service_cache(current_user.id) - return provider_key + if "api_key" in update_data or "config" in update_data: + api_key = update_data.pop("api_key", None) or old_api_key + config = update_data.pop("config", None) or old_config + _validate_provider_cls_init(db_provider_key.provider_name, db_provider_key.base_url, config) + serialized_api_key_config = provider_cls.serialize_api_key_config(api_key, config) + update_data['encrypted_api_key'] = encrypt_api_key(serialized_api_key_config) + + for field, value in update_data.items(): + setattr(db_provider_key, field, value) + + return db_provider_key -def _update_provider_key_internal( +async def _update_provider_key_internal( provider_name: str, - provider_key_in: ProviderKeyUpdate, - db: Session, - current_user: User, -) -> ProviderKeyModel: - """Internal. Update a provider key.""" - provider_key = ( - db.query(ProviderKeyModel) - .filter( - ProviderKeyModel.user_id == current_user.id, + provider_key_update: ProviderKeyUpdate, + db: AsyncSession, + current_user: UserModel, +) -> ProviderKey: + """ + Internal logic to update a provider key for the current user. + """ + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.provider_name == provider_name, + ProviderKeyModel.user_id == current_user.id, ) - .first() ) - if not provider_key: - raise HTTPException( - status_code=404, - detail=f"Provider key for {provider_name} not found", - ) + db_provider_key = result.scalar_one_or_none() + + if not db_provider_key: + raise HTTPException(status_code=404, detail="Provider key not found") + + db_provider_key = await _process_provider_key_update_data(db_provider_key, provider_key_update) - # try to initialize the provider adapter if key info is provided - try: - provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls(provider_name) - _, old_config = provider_adapter_cls.deserialize_api_key_config( - decrypt_api_key(provider_key.encrypted_api_key) - ) - if provider_key_in.api_key or provider_key_in.config: - serialized_api_key_config = provider_adapter_cls.serialize_api_key_config( - provider_key_in.api_key, provider_key_in.config - ) - provider_key.encrypted_api_key = encrypt_api_key(serialized_api_key_config) - if provider_key_in.base_url is not None: - provider_key.base_url = provider_key_in.base_url - if provider_key_in.model_mapping is not None: - provider_key.model_mapping = json.dumps(provider_key_in.model_mapping) - - provider_adapter_cls( - provider_name, - provider_key.base_url, - config=provider_key_in.config or old_config, - ) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Error initializing provider {provider_name}: {e}", - ) + await db.commit() + await db.refresh(db_provider_key) + + # Invalidate caches after updating a provider key + await invalidate_provider_service_cache_async(current_user.id) - db.commit() - db.refresh(provider_key) - invalidate_provider_service_cache(current_user.id) - return provider_key + return ProviderKey.model_validate(db_provider_key) -def _delete_provider_key_internal( - provider_name: str, db: Session, current_user: User +async def _process_provider_key_delete_data( + db: AsyncSession, + provider_name: str, + user_id: int, ) -> ProviderKeyModel: - """Internal. 
Delete a provider key.""" - provider_key = ( - db.query(ProviderKeyModel) - .filter( - ProviderKeyModel.user_id == current_user.id, + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.provider_name == provider_name, + ProviderKeyModel.user_id == user_id, ) - .first() ) - if not provider_key: - raise HTTPException( - status_code=404, - detail=f"Provider key for {provider_name} not found", - ) - db.delete(provider_key) - db.commit() - invalidate_provider_service_cache(current_user.id) - return provider_key + db_provider_key = result.scalar_one_or_none() + + if not db_provider_key: + raise HTTPException(status_code=404, detail="Provider key not found") + + # Store the provider key data before deletion + provider_key_data = ProviderKey.model_validate(db_provider_key) + + await db.delete(db_provider_key) + + return provider_key_data + + +async def _delete_provider_key_internal( + provider_name: str, db: AsyncSession, current_user: UserModel +) -> ProviderKey: + """ + Internal logic to delete a provider key for the current user. + """ + provider_key_data = await _process_provider_key_delete_data(db, provider_name, current_user.id) + await db.commit() + + # Invalidate caches after deleting a provider key + await invalidate_provider_service_cache_async(current_user.id) + + return provider_key_data + +# --- API Endpoints --- @router.get("/", response_model=list[ProviderKey]) -def get_provider_keys( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), +async def get_provider_keys( + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _get_provider_keys_internal(db, current_user) + return await _get_provider_keys_internal(db, current_user) @router.post("/", response_model=ProviderKey) -def create_provider_key( - provider_key_in: ProviderKeyCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), +async def create_provider_key( + provider_key_create: ProviderKeyCreate, + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _create_provider_key_internal(provider_key_in, db, current_user) + return await _create_provider_key_internal(provider_key_create, db, current_user) @router.put("/{provider_name}", response_model=ProviderKey) -def update_provider_key( +async def update_provider_key( provider_name: str, - provider_key_in: ProviderKeyUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), + provider_key_update: ProviderKeyUpdate, + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _update_provider_key_internal( - provider_name, provider_key_in, db, current_user + return await _update_provider_key_internal( + provider_name, provider_key_update, db, current_user ) @router.delete("/{provider_name}", response_model=ProviderKey) -def delete_provider_key( +async def delete_provider_key( provider_name: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _delete_provider_key_internal(provider_name, db, current_user) + return await _delete_provider_key_internal(provider_name, db, current_user) + + +# --- Clerk API Routes --- -# Clerk versions of the routes @router.get("/clerk", 
response_model=list[ProviderKey])
-def get_provider_keys_clerk(
-    db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_active_user_from_clerk),
+async def get_provider_keys_clerk(
+    db: AsyncSession = Depends(get_async_db),
+    current_user: UserModel = Depends(get_current_active_user_from_clerk),
 ) -> Any:
-    return _get_provider_keys_internal(db, current_user)
+    return await _get_provider_keys_internal(db, current_user)
 
 
 @router.post("/clerk", response_model=ProviderKey)
-def create_provider_key_clerk(
-    provider_key_in: ProviderKeyCreate,
-    db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_active_user_from_clerk),
+async def create_provider_key_clerk(
+    provider_key_create: ProviderKeyCreate,
+    db: AsyncSession = Depends(get_async_db),
+    current_user: UserModel = Depends(get_current_active_user_from_clerk),
 ) -> Any:
-    return _create_provider_key_internal(provider_key_in, db, current_user)
+    return await _create_provider_key_internal(provider_key_create, db, current_user)
 
 
 @router.put("/clerk/{provider_name}", response_model=ProviderKey)
-def update_provider_key_clerk(
+async def update_provider_key_clerk(
     provider_name: str,
-    provider_key_in: ProviderKeyUpdate,
-    db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_active_user_from_clerk),
+    provider_key_update: ProviderKeyUpdate,
+    db: AsyncSession = Depends(get_async_db),
+    current_user: UserModel = Depends(get_current_active_user_from_clerk),
 ) -> Any:
-    return _update_provider_key_internal(
-        provider_name, provider_key_in, db, current_user
+    return await _update_provider_key_internal(
+        provider_name, provider_key_update, db, current_user
     )
 
 
 @router.delete("/clerk/{provider_name}", response_model=ProviderKey)
-def delete_provider_key_clerk(
+async def delete_provider_key_clerk(
     provider_name: str,
-    db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_active_user_from_clerk),
+    db: AsyncSession = Depends(get_async_db),
+    current_user: UserModel = Depends(get_current_active_user_from_clerk),
 ) -> Any:
-    return _delete_provider_key_internal(provider_name, db, current_user)
+    return await _delete_provider_key_internal(provider_name, db, current_user)
 
 
 # --- Batch Upsert API Endpoint ---
 
 
-def _batch_upsert_provider_keys_internal(
+async def _batch_upsert_provider_keys_internal(
     items: list[ProviderKeyUpsertItem],
-    db: Session,
-    current_user: User,
-) -> list[ProviderKeyModel]:
+    db: AsyncSession,
+    current_user: UserModel,
+) -> list[ProviderKey]:
     """
     Internal logic for batch creating or updating provider keys for the current user.
     """
     processed_keys: list[ProviderKeyModel] = []
+    processed: bool = False
 
     # 1. Fetch all existing keys for the user
-    existing_keys_query = (
-        db.query(ProviderKeyModel)
-        .filter(ProviderKeyModel.user_id == current_user.id)
-        .all()
+    result = await db.execute(
        select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id)
     )
+    existing_keys_query = result.scalars().all()
 
     # 2. 
Map them by provider_name for efficient lookup existing_keys_map: dict[str, ProviderKeyModel] = { key.provider_name: key for key in existing_keys_query @@ -278,98 +313,23 @@ def _batch_upsert_provider_keys_internal( for item in items: if "****" in item.api_key: continue + try: + existing_provider_key: ProviderKeyModel | None = existing_keys_map.get(item.provider_name) + # Handle deletion if api_key is "DELETE" if item.api_key == "DELETE": - try: - _delete_provider_key_internal(item.provider_name, db, current_user) - except HTTPException as e: - if ( - e.status_code != status.HTTP_404_NOT_FOUND - ): # Ignore 404 errors for missing keys - raise - continue - - db_key_to_process: ProviderKeyModel | None = existing_keys_map.get( - item.provider_name - ) - - if db_key_to_process: # Update existing key - try: - # try to initialize the provider adapter if key info is provided - provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls( - item.provider_name - ) - _, old_config = provider_adapter_cls.deserialize_api_key_config( - decrypt_api_key(db_key_to_process.encrypted_api_key) - ) - if item.api_key or item.config: - serialized_api_key_config = ( - provider_adapter_cls.serialize_api_key_config( - item.api_key, item.config - ) - ) - db_key_to_process.encrypted_api_key = encrypt_api_key( - serialized_api_key_config - ) - if ( - item.base_url is not None - ): # Allows setting base_url to "" or null - db_key_to_process.base_url = item.base_url - if item.model_mapping is not None: - db_key_to_process.model_mapping = json.dumps(item.model_mapping) - elif ( - hasattr(item, "model_mapping") and item.model_mapping is None - ): # Explicitly clear if None - db_key_to_process.model_mapping = None - provider_adapter_cls( - item.provider_name, - db_key_to_process.base_url, - config=item.config or old_config, - ) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Error updating provider {item.provider_name}: {e}", - ) - # No need to db.add() as it's already tracked by the session + if existing_provider_key: + await _process_provider_key_delete_data(db, item.provider_name, current_user.id) + processed = True + elif existing_provider_key: # Update existing key + db_key_to_process = await _process_provider_key_update_data(existing_provider_key, ProviderKeyUpdate.model_validate(item)) + processed_keys.append(db_key_to_process) + processed = True else: # Create new key - if not item.api_key: - raise HTTPException( - status_code=400, - detail=f"api_key is required to create a new provider key for {item.provider_name}", - ) - model_mapping_json = ( - json.dumps(item.model_mapping) if item.model_mapping else None - ) - provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls( - item.provider_name - ) - # try to initialize the provider adapter - try: - provider_adapter_cls( - item.provider_name, item.base_url, config=item.config - ) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Error initializing provider {item.provider_name}: {e}", - ) - serialized_api_key_config = ( - provider_adapter_cls.serialize_api_key_config( - item.api_key, item.config - ) - ) - db_key_to_process = ProviderKeyModel( - provider_name=item.provider_name, - encrypted_api_key=encrypt_api_key(serialized_api_key_config), - user_id=current_user.id, - base_url=item.base_url, - model_mapping=model_mapping_json, - ) - db.add(db_key_to_process) - - processed_keys.append(db_key_to_process) + db_key_to_process = await _process_provider_key_create_data(db, 
ProviderKeyCreate.model_validate(item), current_user.id) + processed_keys.append(db_key_to_process) + processed = True except HTTPException as http_exc: # db.rollback() # Optional: rollback if any item fails, or handle partial success @@ -390,16 +350,15 @@ def _batch_upsert_provider_keys_internal( detail=f"An unexpected error occurred while processing '{item.provider_name}'.", ) - if processed_keys: + if processed: try: - db.commit() + await db.commit() for key in processed_keys: - db.refresh( - key - ) # Refresh each key to get DB-generated values like id, timestamps - invalidate_provider_service_cache(current_user.id) + await db.refresh(key) # Refresh each key to get DB-generated values like id, timestamps + processed_keys = [ProviderKey.model_validate(key) for key in processed_keys] + await invalidate_provider_service_cache_async(current_user.id) except Exception as e: - db.rollback() + await db.rollback() error_message_prefix = "Error during final commit/refresh in batch upsert" if hasattr(current_user, "email"): # Check if it's a full User object error_message_prefix += f" (User: {current_user.email})" @@ -412,24 +371,24 @@ def _batch_upsert_provider_keys_internal( @router.post("/batch-upsert", response_model=list[ProviderKey]) -def batch_upsert_provider_keys( +async def batch_upsert_provider_keys( items: list[ProviderKeyUpsertItem], - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: """ Batch create or update provider keys for the current user. """ - return _batch_upsert_provider_keys_internal(items, db, current_user) + return await _batch_upsert_provider_keys_internal(items, db, current_user) @router.post("/clerk/batch-upsert", response_model=list[ProviderKey]) -def batch_upsert_provider_keys_clerk( +async def batch_upsert_provider_keys_clerk( items: list[ProviderKeyUpsertItem], - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user_from_clerk), + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: """ Batch create or update provider keys for the current user (Clerk authenticated). """ - return _batch_upsert_provider_keys_internal(items, db, current_user) + return await _batch_upsert_provider_keys_internal(items, db, current_user) diff --git a/app/api/routes/proxy.py b/app/api/routes/proxy.py index 669973a..82384b6 100644 --- a/app/api/routes/proxy.py +++ b/app/api/routes/proxy.py @@ -2,7 +2,9 @@ from typing import Any from fastapi import APIRouter, Depends, HTTPException, Request -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from starlette.responses import StreamingResponse from app.api.dependencies import get_user_by_api_key @@ -15,8 +17,9 @@ ImageGenerationRequest, ) from app.core.async_cache import async_provider_service_cache -from app.core.database import get_db +from app.core.database import get_async_db from app.core.logger import get_logger +from app.models.forge_api_key import ForgeApiKey from app.models.user import User from app.services.provider_service import ProviderService @@ -29,7 +32,7 @@ # None → unrestricted, [] → explicitly no providers. 
# ------------------------------------------------------------- async def _get_allowed_provider_names( - request: Request, db: Session + request: Request, db: AsyncSession ) -> list[str] | None: api_key = getattr(request.state, "forge_api_key", None) if api_key is None: @@ -43,19 +46,15 @@ async def _get_allowed_provider_names( if allowed is not None: return allowed - from sqlalchemy.orm import joinedload - - from app.models.forge_api_key import ForgeApiKey - allowed = await async_provider_service_cache.get(f"forge_scope:{api_key}") if allowed is None: - forge_key = ( - db.query(ForgeApiKey) - .options(joinedload(ForgeApiKey.allowed_provider_keys)) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.key == f"forge-{api_key}", ForgeApiKey.is_active) - .first() ) + forge_key = result.scalar_one_or_none() if forge_key is None: raise HTTPException( status_code=401, detail="Forge API key not found or inactive" @@ -74,7 +73,7 @@ async def create_chat_completion( request: Request, chat_request: ChatCompletionRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create a chat completion (OpenAI-compatible endpoint). @@ -123,7 +122,7 @@ async def create_completion( request: Request, completion_request: CompletionRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create a completion (OpenAI-compatible endpoint). @@ -166,7 +165,7 @@ async def create_image_generation( request: Request, image_generation_request: ImageGenerationRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create an image generation (OpenAI-compatible endpoint). @@ -197,7 +196,7 @@ async def create_image_edits( request: Request, image_edits_request: ImageEditsRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: try: provider_service = await ProviderService.async_get_instance(user, db) @@ -221,7 +220,7 @@ async def create_image_edits( async def list_models( request: Request, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> dict[str, Any]: """ List available models. Only models from providers that are within the scope of the @@ -237,6 +236,7 @@ async def list_models( ) return {"object": "list", "data": models} except Exception as err: + logger.error(f"Error listing models: {str(err)}") raise HTTPException( status_code=500, detail=f"Error listing models: {str(err)}" ) from err @@ -248,7 +248,7 @@ async def create_embeddings( request: Request, embeddings_request: EmbeddingsRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create embeddings (OpenAI-compatible endpoint). 
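
[Editor sketch, not part of the patch: the proxy routes above now pair `select()` with `selectinload` so the provider-key collection is loaded eagerly before the async session goes out of scope. A minimal, self-contained illustration of that pattern follows; `ForgeApiKey`, `get_async_db`, and the `allowed_provider_keys` relationship are taken from this patch, while the standalone `/scope/{api_key}` route is hypothetical (the real handler reads the key from `request.state` and consults the cache first).]

# Sketch only -- assumes the models and dependency introduced in this patch.
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.core.database import get_async_db
from app.models.forge_api_key import ForgeApiKey

router = APIRouter()


@router.get("/scope/{api_key}")
async def read_scope(api_key: str, db: AsyncSession = Depends(get_async_db)) -> list[str]:
    # selectinload fetches the collection with a second SELECT up front, so no
    # lazy load fires after the async session context ends (which would raise)
    result = await db.execute(
        select(ForgeApiKey)
        .options(selectinload(ForgeApiKey.allowed_provider_keys))
        .filter(ForgeApiKey.key == f"forge-{api_key}", ForgeApiKey.is_active)
    )
    forge_key = result.scalar_one_or_none()
    if forge_key is None:
        raise HTTPException(status_code=401, detail="Forge API key not found or inactive")
    return [pk.provider_name for pk in forge_key.allowed_provider_keys]
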
diff --git a/app/api/routes/stats.py b/app/api/routes/stats.py index 7aa39be..a470e07 100644 --- a/app/api/routes/stats.py +++ b/app/api/routes/stats.py @@ -2,12 +2,12 @@ from typing import Any from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies import ( get_current_active_user_from_clerk, get_current_user, - get_db, + get_async_db, get_user_by_api_key, ) from app.models.user import User @@ -29,7 +29,7 @@ async def get_user_stats( end_date: date | None = Query( None, description="End date for filtering (YYYY-MM-DD)" ), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ): """ Get aggregated usage statistics for the current user, queried from request logs. @@ -38,7 +38,7 @@ async def get_user_stats( """ # Note: Service layer now handles aggregation and filtering # We pass the query parameters directly to the service method - stats = UsageStatsService.get_user_stats( + stats = await UsageStatsService.get_user_stats( db=db, user_id=current_user.id, provider=provider, @@ -62,7 +62,7 @@ async def get_user_stats_clerk( end_date: date | None = Query( None, description="End date for filtering (YYYY-MM-DD)" ), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ): """ Get aggregated usage statistics for the current user, queried from request logs. @@ -71,7 +71,7 @@ async def get_user_stats_clerk( """ # Note: Service layer now handles aggregation and filtering # We pass the query parameters directly to the service method - stats = UsageStatsService.get_user_stats( + stats = await UsageStatsService.get_user_stats( db=db, user_id=current_user.id, provider=provider, @@ -93,7 +93,7 @@ async def get_all_stats( end_date: date | None = Query( None, description="End date for filtering (YYYY-MM-DD)" ), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ): """ Get aggregated usage statistics for all users, queried from request logs. 
@@ -106,7 +106,7 @@ async def get_all_stats(
             status_code=403, detail="Not authorized to access admin statistics"
         )
 
-    stats = UsageStatsService.get_all_stats(
+    stats = await UsageStatsService.get_all_stats(
         db=db, provider=provider, model=model, start_date=start_date, end_date=end_date
     )
     return stats
diff --git a/app/api/routes/users.py b/app/api/routes/users.py
index 9abdce9..52d84ae 100644
--- a/app/api/routes/users.py
+++ b/app/api/routes/users.py
@@ -1,17 +1,15 @@
 from typing import Any
 
 from fastapi import APIRouter, Depends, HTTPException
-from sqlalchemy.orm import Session
-
-from app.api.dependencies import (
-    get_current_active_user,
-    get_current_active_user_from_clerk,
-)
-from app.api.schemas.user import MaskedUser, User, UserCreate, UserUpdate
-from app.core.cache import invalidate_user_cache
-from app.core.database import get_db
-from app.core.logger import get_logger
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.api.dependencies import get_current_active_user, get_current_active_user_from_clerk
+from app.api.schemas.user import User, UserCreate, UserUpdate, MaskedUser
+from app.core.cache import invalidate_user_cache
+from app.core.database import get_async_db
 from app.core.security import get_password_hash
+from app.core.logger import get_logger
 from app.models.user import User as UserModel
 from app.services.provider_service import create_default_tensorblock_provider_for_user
 
@@ -20,53 +17,61 @@
 router = APIRouter()
 
 
-@router.post("/", response_model=User, status_code=201)
-def create_user(
-    user_in: UserCreate, db: Session = Depends(get_db)
-) -> Any: # pragma: no cover - (Covered by test_user_creation_and_login)
+@router.post("/", response_model=User)
+async def create_user(
+    user_in: UserCreate, db: AsyncSession = Depends(get_async_db)
+) -> Any:
     """
-    Create new user.
+    Create a new user. 

""" - db_user = db.query(UserModel).filter(UserModel.email == user_in.email).first() + # Check if email already exists + result = await db.execute( + select(UserModel).filter(UserModel.email == user_in.email) + ) + db_user = result.scalar_one_or_none() if db_user: raise HTTPException( - status_code=400, - detail="The user with this email already exists in the system.", + status_code=400, detail="Email already registered" ) - db_user = db.query(UserModel).filter(UserModel.username == user_in.username).first() + + # Check if username already exists + result = await db.execute( + select(UserModel).filter(UserModel.username == user_in.username) + ) + db_user = result.scalar_one_or_none() if db_user: raise HTTPException( - status_code=400, - detail="The user with this username already exists in the system.", + status_code=400, detail="Username already registered" ) - + + # Create new user hashed_password = get_password_hash(user_in.password) db_user = UserModel( - username=user_in.username, email=user_in.email, + username=user_in.username, hashed_password=hashed_password, - # Removed automatic API key generation on user creation - # api_key=generate_forge_api_key(), # Users will create keys via /api-keys endpoint - is_active=True, # Default to active, admin can deactivate ) db.add(db_user) - db.commit() - db.refresh(db_user) + await db.commit() + await db.refresh(db_user) # Create default TensorBlock provider for the new user try: - create_default_tensorblock_provider_for_user(db_user.id, db) + await create_default_tensorblock_provider_for_user(db_user.id, db) except Exception as e: # Log error but don't fail user creation - logger.warning( - f"Failed to create default TensorBlock provider for user {db_user.id}: {e}" - ) + logger.error({ + "message": f"Error creating default TensorBlock provider for user {db_user.id}", + "extra": { + "error": str(e), + } + }) return db_user @router.get("/me", response_model=MaskedUser) -def read_user_me( +async def read_user_me( current_user: UserModel = Depends(get_current_active_user), ) -> Any: """ @@ -85,7 +90,7 @@ def read_user_me( @router.get("/me/clerk", response_model=MaskedUser) -def read_user_me_clerk( +async def read_user_me_clerk( current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: """ @@ -103,9 +108,9 @@ def read_user_me_clerk( @router.put("/me", response_model=User) -def update_user_me( +async def update_user_me( user_in: UserUpdate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: """ @@ -119,8 +124,8 @@ def update_user_me( current_user.hashed_password = get_password_hash(user_in.password) db.add(current_user) - db.commit() - db.refresh(current_user) + await db.commit() + await db.refresh(current_user) invalidate_user_cache( current_user.id ) # Assuming user_id is the cache key for user object @@ -131,15 +136,15 @@ def update_user_me( @router.put("/me/clerk", response_model=User) -def update_user_me_clerk( +async def update_user_me_clerk( user_in: UserUpdate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: """ Update current user from Clerk. """ - return update_user_me(user_in, db, current_user) + return await update_user_me(user_in, db, current_user) # The regenerate_api_key and regenerate_api_key_clerk endpoints have been removed. 
diff --git a/app/api/routes/webhooks.py b/app/api/routes/webhooks.py index ccfaa8a..2dac064 100644 --- a/app/api/routes/webhooks.py +++ b/app/api/routes/webhooks.py @@ -1,12 +1,14 @@ import json import os +from typing import Any from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session -from svix.webhooks import Webhook, WebhookVerificationError +from sqlalchemy.ext.asyncio import AsyncSession +from svix import Webhook, WebhookVerificationError -from app.core.database import get_db +from app.core.database import get_async_db from app.core.logger import get_logger from app.core.security import generate_forge_api_key from app.models.user import User @@ -21,7 +23,7 @@ @router.post("/clerk") -async def clerk_webhook_handler(request: Request, db: Session = Depends(get_db)): +async def clerk_webhook_handler(request: Request, db: AsyncSession = Depends(get_async_db)): """ Handle Clerk webhooks for user events. @@ -99,100 +101,13 @@ async def clerk_webhook_handler(request: Request, db: Session = Depends(get_db)) # Handle different event types if event_type == "user.created": - # Check if user already exists - user = db.query(User).filter(User.clerk_user_id == clerk_user_id).first() - if user: - return {"status": "success", "message": "User already exists"} - - # Check if user exists with this email - existing_user = db.query(User).filter(User.email == email).first() - if existing_user: - # Link existing user to Clerk ID - try: - existing_user.clerk_user_id = clerk_user_id - db.commit() - return {"status": "success", "message": "Linked to existing user"} - except IntegrityError: - # Another request might have already linked this user or created a new one - db.rollback() - # Retry the query to get the user - user = ( - db.query(User) - .filter(User.clerk_user_id == clerk_user_id) - .first() - ) - if user: - return {"status": "success", "message": "User already exists"} - # If still no user, continue with creation attempt - - # Create new user - forge_api_key = generate_forge_api_key() - - try: - user = User( - email=email, - username=username, - clerk_user_id=clerk_user_id, - is_active=True, - forge_api_key=forge_api_key, - ) - db.add(user) - db.commit() - - # Create default TensorBlock provider for the new user - try: - create_default_tensorblock_provider_for_user(user.id, db) - except Exception as e: - # Log error but don't fail user creation - logger.warning( - f"Failed to create default TensorBlock provider for user {user.id}: {e}" - ) - - return {"status": "success", "message": "User created"} - except IntegrityError as e: - # Handle race condition: another request might have created the user - db.rollback() - if "users_clerk_user_id_key" in str(e) or "clerk_user_id" in str(e): - # Retry the query to get the user that was created by another request - user = ( - db.query(User) - .filter(User.clerk_user_id == clerk_user_id) - .first() - ) - if user: - return {"status": "success", "message": "User already exists"} - else: - # This shouldn't happen, but handle it gracefully - return { - "status": "error", - "message": "Failed to create user due to database constraint", - } - else: - # Re-raise other integrity errors - raise + await handle_user_created(event_data, db) elif event_type == "user.updated": - # Update user if they exist - user = db.query(User).filter(User.clerk_user_id == clerk_user_id).first() - if not user: - return {"status": "error", "message": "User not found"} - 
-        # Update fields
-        if email and user.email != email:
-            user.email = email
-        if username and user.username != username:
-            user.username = username
-
-        db.commit()
-        return {"status": "success", "message": "User updated"}
+        await handle_user_updated(event_data, db)
 
     elif event_type == "user.deleted":
-        # Deactivate user rather than delete
-        user = db.query(User).filter(User.clerk_user_id == clerk_user_id).first()
-        if user:
-            user.is_active = False
-            db.commit()
-        return {"status": "success", "message": "User deactivated"}
+        await handle_user_deleted(event_data, db)
 
     return {"status": "success", "message": f"Event {event_type} processed"}
 
@@ -205,3 +120,121 @@ async def clerk_webhook_handler(request: Request, db: Session = Depends(get_db))
         status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
         detail=f"Error processing webhook: {str(e)}",
     )
+
+
+async def handle_user_created(event_data: dict, db: AsyncSession):
+    """Handle user.created event from Clerk"""
+    try:
+        clerk_user_id = event_data.get("id")
+        email = event_data.get("email_addresses", [{}])[0].get("email_address", "")
+        username = (
+            event_data.get("username")
+            or event_data.get("first_name", "")
+            or email.split("@")[0]
+        )
+
+        logger.info(f"Creating user from Clerk webhook: {username} ({email})")
+
+        # Check if user already exists by clerk_user_id
+        result = await db.execute(
+            select(User).filter(User.clerk_user_id == clerk_user_id)
+        )
+        user = result.scalar_one_or_none()
+        if user:
+            logger.info(f"User {username} already exists with Clerk ID")
+            return
+
+        # Check if user exists with this email
+        result = await db.execute(
+            select(User).filter(User.email == email)
+        )
+        existing_user = result.scalar_one_or_none()
+        if existing_user:
+            # Link existing user to Clerk ID
+            existing_user.clerk_user_id = clerk_user_id
+            await db.commit()
+            logger.info(f"Linked existing user {existing_user.username} to Clerk ID")
+            return
+
+        # Create new user
+        user = User(
+            username=username,
+            email=email,
+            clerk_user_id=clerk_user_id,
+            is_active=True,
+            hashed_password="",  # Clerk handles authentication
+        )
+        db.add(user)
+        await db.commit()
+        await db.refresh(user)
+
+        # Create default provider for the user (the function is async now)
+        await create_default_tensorblock_provider_for_user(user.id, db)
+
+        logger.info(f"Successfully created user {username} with ID {user.id}")
+
+    except Exception as e:
+        await db.rollback()
+        logger.error(f"Failed to create user from webhook: {e}", exc_info=True)
+        raise
+
+
+async def handle_user_updated(event_data: dict, db: AsyncSession):
+    """Handle user.updated event from Clerk"""
+    try:
+        clerk_user_id = event_data.get("id")
+        email = event_data.get("email_addresses", [{}])[0].get("email_address", "")
+        username = (
+            event_data.get("username")
+            or event_data.get("first_name", "")
+            or email.split("@")[0]
+        )
+
+        logger.info(f"Updating user from Clerk webhook: {username} ({email})")
+
+        result = await db.execute(
+            select(User).filter(User.clerk_user_id == clerk_user_id)
+        )
+        user = result.scalar_one_or_none()
+        if not user:
+            logger.warning(f"User with Clerk ID {clerk_user_id} not found for update")
+            return
+
+        # Update user information
+        user.username = username
+        user.email = email
+        await db.commit()
+
+        logger.info(f"Successfully updated user {username}")
+
+    except Exception as e:
+        await db.rollback()
+        logger.error(f"Failed to update user from webhook: {e}", exc_info=True)
+        raise
+
+
+async def handle_user_deleted(event_data: dict, db: AsyncSession):
+    """Handle user.deleted event from Clerk"""
+    try:
+        clerk_user_id = 
event_data.get("id") + + logger.info(f"Deleting user from Clerk webhook: {clerk_user_id}") + + result = await db.execute( + select(User).filter(User.clerk_user_id == clerk_user_id) + ) + user = result.scalar_one_or_none() + if not user: + logger.warning(f"User with Clerk ID {clerk_user_id} not found for deletion") + return + + # Deactivate user instead of deleting to preserve data integrity + user.is_active = False + await db.commit() + + logger.info(f"Successfully deactivated user {user.username}") + + except Exception as e: + await db.rollback() + logger.error(f"Failed to delete user from webhook: {e}", exc_info=True) + raise diff --git a/app/api/schemas/provider_key.py b/app/api/schemas/provider_key.py index aaaac2f..17dbc57 100644 --- a/app/api/schemas/provider_key.py +++ b/app/api/schemas/provider_key.py @@ -64,6 +64,7 @@ class ProviderKeyCreate(ProviderKeyBase): class ProviderKeyUpdate(BaseModel): api_key: str | None = None + config: dict[str, str] | None = None base_url: str | None = None model_mapping: dict[str, str] | None = None diff --git a/app/core/async_cache.py b/app/core/async_cache.py index a3e34e6..c3ba536 100644 --- a/app/core/async_cache.py +++ b/app/core/async_cache.py @@ -10,7 +10,8 @@ from collections.abc import Callable from typing import Any, TypeVar -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select from app.api.schemas.cached_user import CachedUser from app.core.logger import get_logger @@ -391,7 +392,7 @@ async def invalidate_all_caches_async() -> None: logger.debug("Cache: Invalidated all caches") -async def warm_cache_async(db: Session) -> None: +async def warm_cache_async(db: AsyncSession) -> None: """Pre-cache frequently accessed data asynchronously""" from app.models.user import User from app.services.provider_service import ProviderService @@ -400,14 +401,17 @@ async def warm_cache_async(db: Session) -> None: logger.info("Cache: Starting cache warm-up...") # Cache active users - active_users = db.query(User).filter(User.is_active).all() + result = await db.execute(select(User).filter(User.is_active)) + active_users = result.scalars().all() + for user in active_users: # Get user's Forge API keys - forge_api_keys = ( - db.query(ForgeApiKey) + result = await db.execute( + select(ForgeApiKey) .filter(ForgeApiKey.user_id == user.id, ForgeApiKey.is_active) - .all() ) + forge_api_keys = result.scalars().all() + for key in forge_api_keys: await cache_user_async(key.key, user) diff --git a/app/core/cache.py b/app/core/cache.py index 26f5693..7fe9fbe 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -1,3 +1,4 @@ +# TODO: deprecate this and move to async cache import functools import os import time @@ -352,11 +353,6 @@ async def warm_cache(db: Session) -> None: # Cache user with their Forge API key cache_user(key.key, user) - # Cache provider services for active users - for user in active_users: - service = ProviderService.get_instance(user, db) - cache_provider_service(user.id, service) - if DEBUG_CACHE: logger.info(f"Cache: Warm-up complete. 
Cached {len(active_users)} users") diff --git a/app/core/database.py b/app/core/database.py index 92cf992..5253f03 100644 --- a/app/core/database.py +++ b/app/core/database.py @@ -1,31 +1,87 @@ import os from dotenv import load_dotenv +from contextlib import asynccontextmanager from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.orm import declarative_base, sessionmaker load_dotenv() +POOL_SIZE = 5 +MAX_OVERFLOW = 10 +MAX_TIMEOUT = 30 +POOL_RECYCLE = 1800 + SQLALCHEMY_DATABASE_URL = os.getenv("DATABASE_URL") if not SQLALCHEMY_DATABASE_URL: raise ValueError("DATABASE_URL environment variable is not set") +# Sync engine and session engine = create_engine( SQLALCHEMY_DATABASE_URL, - pool_size=5, - max_overflow=10, - pool_timeout=30, - pool_recycle=1800, + pool_size=POOL_SIZE, + max_overflow=MAX_OVERFLOW, + pool_timeout=MAX_TIMEOUT, + pool_recycle=POOL_RECYCLE, + echo=False, ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() - -# Dependency +# Sync dependency def get_db(): db = SessionLocal() try: yield db finally: db.close() + + +# Async engine and session (new) +# Convert the DATABASE_URL to async format if it's using psycopg2 +ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL +if SQLALCHEMY_DATABASE_URL.startswith("postgresql://"): + ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://") +elif SQLALCHEMY_DATABASE_URL.startswith("postgresql+psycopg2://"): + ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL.replace("postgresql+psycopg2://", "postgresql+asyncpg://") + +async_engine = create_async_engine( + ASYNC_DATABASE_URL, + pool_size=POOL_SIZE, + max_overflow=MAX_OVERFLOW, + pool_timeout=MAX_TIMEOUT, + pool_recycle=POOL_RECYCLE, + echo=False, +) + +AsyncSessionLocal = async_sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, +) + +Base = declarative_base() + + +# Async dependency +async def get_async_db(): + async with AsyncSessionLocal() as session: + try: + yield session + finally: + await session.close() + + +@asynccontextmanager +async def get_db_session(): + """Async context manager for database sessions""" + async with AsyncSessionLocal() as session: + try: + yield session + except Exception: + await session.rollback() + raise + finally: + await session.close() \ No newline at end of file diff --git a/app/models/provider_key.py b/app/models/provider_key.py index 9fb4b2b..eb191e3 100644 --- a/app/models/provider_key.py +++ b/app/models/provider_key.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy import Column, ForeignKey, Integer, String, JSON from sqlalchemy.orm import relationship from app.models.forge_api_key import forge_api_key_provider_scope_association @@ -18,7 +18,7 @@ class ProviderKey(BaseModel): base_url = Column( String, nullable=True ) # Allow custom base URLs for some providers - model_mapping = Column(String, nullable=True) # JSON string for model name mappings + model_mapping = Column(JSON, nullable=True) # JSON dict for model name mappings # Relationship to ForgeApiKeys that have this provider key in their scope scoped_forge_api_keys = relationship( diff --git a/app/services/provider_service.py b/app/services/provider_service.py index 61688e3..281050f 100644 --- a/app/services/provider_service.py +++ b/app/services/provider_service.py @@ -6,13 +6,10 @@ from collections.abc import AsyncGenerator from 
typing import Any, ClassVar -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession -# For async support -from app.core.cache import ( - DEBUG_CACHE, - provider_service_cache, -) +from app.core.async_cache import async_provider_service_cache, DEBUG_CACHE from app.core.logger import get_logger from app.core.security import decrypt_api_key, encrypt_api_key from app.exceptions.exceptions import InvalidProviderException, BaseInvalidRequestException, InvalidForgeKeyException @@ -53,7 +50,7 @@ class method rather than direct instantiation. # ------------------------------------------------------------------ # Helper for building a cache key that works across all workers. - # Stored via app.core.cache.provider_service_cache which resolves to + # Stored via app.core.async_cache.async_provider_service_cache which resolves to # either RedisCache or in-memory Cache. # ------------------------------------------------------------------ @@ -62,7 +59,7 @@ def _model_cache_key(cls, provider_name: str, cache_key: str) -> str: # Using a stable namespace makes invalidation easier return f"models:{provider_name}:{cache_key}" - def __init__(self, user_id: int, db: Session): + def __init__(self, user_id: int, db: AsyncSession): self.user_id = user_id self.db = db self.provider_keys: dict[str, dict[str, Any]] = {} @@ -71,28 +68,7 @@ def __init__(self, user_id: int, db: Session): self._keys_loaded = False @classmethod - def get_instance(cls, user: User, db: Session) -> "ProviderService": - """Get a cached instance of ProviderService for a user or create a new one""" - cache_key = f"provider_service:{user.id}" - cached_instance = provider_service_cache.get(cache_key) - if cached_instance: - if DEBUG_CACHE: - logger.debug( - f"Using cached ProviderService instance for user {user.id}" - ) - # Update the db session reference for the cached instance - cached_instance.db = db - return cached_instance - - # No cached instance found, create a new one - if DEBUG_CACHE: - logger.debug(f"Creating new ProviderService instance for user {user.id}") - instance = cls(user.id, db) - provider_service_cache.set(cache_key, instance) - return instance - - @classmethod - async def async_get_instance(cls, user: User, db: Session) -> "ProviderService": + async def async_get_instance(cls, user: User, db: AsyncSession) -> "ProviderService": """Get a cached instance of ProviderService for a user or create a new one (async version)""" from app.core.async_cache import async_provider_service_cache @@ -117,7 +93,7 @@ async def async_get_instance(cls, user: User, db: Session) -> "ProviderService": return instance @classmethod - def get_cached_models( + async def get_cached_models( cls, provider_name: str, cache_key: str ) -> list[dict[str, Any]] | None: """Return cached model list if present (shared cache).""" @@ -131,7 +107,7 @@ def get_cached_models( return l1_entry[1] # -------- L2: shared cache (Redis / memory) -------- - models = provider_service_cache.get(key) + models = await async_provider_service_cache.get(key) if models: # populate L1 cls._models_l1_cache[key] = (time.time() + cls._models_cache_ttl, models) @@ -140,14 +116,14 @@ def get_cached_models( return models @classmethod - def cache_models( + async def cache_models( cls, provider_name: str, cache_key: str, models: list[dict[str, Any]] ) -> None: """Store models in the shared cache with a TTL.""" key = cls._model_cache_key(provider_name, cache_key) # Write to shared cache (L2) - provider_service_cache.set(key, models, 
ttl=cls._models_cache_ttl)
+        await async_provider_service_cache.set(key, models, ttl=cls._models_cache_ttl)
 
         # Populate/refresh L1
         cls._models_l1_cache[key] = (time.time() + cls._models_cache_ttl, models)
@@ -163,30 +139,14 @@ def _get_adapters(self) -> dict[str, ProviderAdapter]:
             ProviderService._adapters_cache = ProviderAdapterFactory.get_all_adapters()
         return ProviderService._adapters_cache
 
-    def _parse_model_mapping(self, mapping_str: str | None) -> dict:
-        if not mapping_str:
-            return {}
-        try:
-            return json.loads(mapping_str)
-        except json.JSONDecodeError:
-            logger.warning(f"Failed to parse model_mapping JSON: {mapping_str}")
-            # Try a literal eval as fallback
-            try:
-                import ast
-
-                return ast.literal_eval(mapping_str)
-            except (SyntaxError, ValueError):
-                logger.warning(f"Could not parse model_mapping: {mapping_str}")
-                return {}
-
-    def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
-        """Load all provider keys for the user synchronously, with lazy loading and caching."""
+    async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
+        """Load all provider keys for the user, with lazy loading and caching."""
         if self._keys_loaded:
             return self.provider_keys
 
         # Try to get provider keys from cache
        cache_key = f"provider_keys:{self.user_id}"
-        cached_keys = provider_service_cache.get(cache_key)
+        cached_keys = await async_provider_service_cache.get(cache_key)
         if cached_keys is not None:
             if DEBUG_CACHE:
                 logger.debug(
@@ -204,13 +164,12 @@ def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
         # Query ProviderKey directly by user_id
         from app.models.provider_key import ProviderKey
 
-        provider_key_records = (
-            self.db.query(ProviderKey).filter(ProviderKey.user_id == self.user_id).all()
-        )
+        result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id))
+        provider_key_records = result.scalars().all()
 
         keys = {}
         for provider_key in provider_key_records:
-            model_mapping = self._parse_model_mapping(provider_key.model_mapping)
+            model_mapping = provider_key.model_mapping or {}
 
             keys[provider_key.provider_name] = {
                 "api_key": decrypt_api_key(provider_key.encrypted_api_key),
@@ -226,7 +185,7 @@ def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
             logger.debug(
-                f"Caching provider keys for user {self.user_id} (TTL: 3600s) (sync)"
+                f"Caching provider keys for user {self.user_id} (TTL: 3600s)"
             )
-        provider_service_cache.set(cache_key, keys, ttl=3600)  # Cache for 1 hour
+        await async_provider_service_cache.set(cache_key, keys, ttl=3600)  # Cache for 1 hour
 
         return keys
 
@@ -257,13 +216,12 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]:
         # Query ProviderKey directly by user_id
         from app.models.provider_key import ProviderKey
 
-        provider_key_records = (
-            self.db.query(ProviderKey).filter(ProviderKey.user_id == self.user_id).all()
-        )
+        result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id))
+        provider_key_records = result.scalars().all()
 
         keys = {}
         for provider_key in provider_key_records:
-            model_mapping = self._parse_model_mapping(provider_key.model_mapping)
+            model_mapping = provider_key.model_mapping or {}
 
             keys[provider_key.provider_name] = {
                 "api_key": decrypt_api_key(provider_key.encrypted_api_key),
@@ -414,7 +372,7 @@ async def list_models(
             cache_key = f"{base_url}:{hash(frozenset(provider_data.get('model_mapping', {}).items()))}"
 
             # Check if we have cached models for this provider
-            cached_models = self.get_cached_models(provider_name, cache_key)
+            cached_models = await self.get_cached_models(provider_name, cache_key)
             if cached_models:
                 models.extend(cached_models)
                 continue
@@ -441,7 +399,7 @@ async def 
_list_models_helper( for model in model_names ] # Cache the results - self.cache_models(provider_name, cache_key, provider_models) + await self.cache_models(provider_name, cache_key, provider_models) return provider_models except Exception as e: @@ -580,12 +538,9 @@ async def process_request( # Record the usage statistics using the new logging method # Use a fresh DB session for logging, since the original request session # may have been closed by FastAPI after the response was returned. - from sqlalchemy.orm import Session + from app.core.database import get_db_session - from app.core.database import SessionLocal - - new_db_session: Session = SessionLocal() - try: + async with get_db_session() as new_db_session: await UsageStatsService.log_api_request( db=new_db_session, user_id=self.user_id, @@ -595,9 +550,7 @@ async def process_request( input_tokens=input_tokens, output_tokens=output_tokens, ) - finally: - new_db_session.close() - return result + return result else: # For streaming responses, wrap the generator to count tokens async def token_counting_stream() -> AsyncGenerator[bytes, None]: @@ -692,12 +645,9 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]: # Use a fresh DB session for logging, since the original request session # may have been closed by FastAPI after the response was returned. - from sqlalchemy.orm import Session - - from app.core.database import SessionLocal + from app.core.database import get_db_session - new_db_session: Session = SessionLocal() - try: + async with get_db_session() as new_db_session: await UsageStatsService.log_api_request( db=new_db_session, user_id=self.user_id, @@ -707,14 +657,12 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]: input_tokens=input_tokens, output_tokens=output_tokens, ) - finally: - new_db_session.close() # End of token_counting_stream function return token_counting_stream() -def create_default_tensorblock_provider_for_user(user_id: int, db: Session) -> None: +async def create_default_tensorblock_provider_for_user(user_id: int, db: AsyncSession) -> None: """ Create a default TensorBlock provider key for a new user. This allows users to use Forge immediately without binding their own API keys. 
@@ -756,12 +704,12 @@ def create_default_tensorblock_provider_for_user(user_id: int, db: Session) -> N ) db.add(provider_key) - db.commit() + await db.commit() logger.info(f"Created default TensorBlock provider for user {user_id}") except Exception as e: - db.rollback() + await db.rollback() logger.error( "Error creating default TensorBlock provider for user {}: {}", user_id, diff --git a/app/services/usage_stats_service.py b/app/services/usage_stats_service.py index 4aec1f2..807f32a 100644 --- a/app/services/usage_stats_service.py +++ b/app/services/usage_stats_service.py @@ -2,7 +2,7 @@ from typing import Any from sqlalchemy import func, select -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.core.logger import get_logger from app.models.api_request_log import ApiRequestLog @@ -15,7 +15,7 @@ class UsageStatsService: @staticmethod async def log_api_request( - db: Session, + db: AsyncSession, user_id: int | None, provider_name: str, model: str, @@ -39,19 +39,19 @@ async def log_api_request( cost=cost, ) db.add(log_entry) - db.commit() + await db.commit() logger.debug( f"Logged API request for user {user_id}: {provider_name}/{model}/{endpoint}" ) except Exception as e: - db.rollback() + await db.rollback() logger.error( f"Failed to log API request for user {user_id}: {e}", exc_info=True ) @staticmethod - def get_user_stats( - db: Session, + async def get_user_stats( + db: AsyncSession, user_id: int, provider: str | None = None, model: str | None = None, @@ -89,7 +89,7 @@ def get_user_stats( query = query.group_by(ApiRequestLog.provider_name, ApiRequestLog.model) - results = db.execute(query).fetchall() + results = await db.execute(query) return [ { @@ -105,8 +105,8 @@ def get_user_stats( ] @staticmethod - def get_all_stats( - db: Session, + async def get_all_stats( + db: AsyncSession, provider: str | None = None, # Add provider filter model: str | None = None, # Add model filter start_date: date | None = None, # Add start_date filter @@ -148,7 +148,7 @@ def get_all_stats( query = query.group_by(ApiRequestLog.provider_name, ApiRequestLog.model) # Execute query - results = db.execute(query).fetchall() + results = await db.execute(query) # Convert results to dictionaries return [ diff --git a/forge-cli.py b/forge-cli.py index 1947449..05f1c1b 100755 --- a/forge-cli.py +++ b/forge-cli.py @@ -219,6 +219,38 @@ def list_provider_keys(self) -> list[dict[str, Any]]: except Exception as e: print(f"❌ Error listing provider keys: {str(e)}") return [] + + def update_provider_key(self, provider_name: str, api_key: str | None = None, base_url: str | None = None, model_mapping: str | None = None, config: str | None = None) -> bool: + """Update a provider key""" + if not self.token: + print("❌ Not authenticated. 
Please login first.")
+            return False
+
+        url = f"{self.api_url}/provider-keys/{provider_name}"
+        headers = {
+            "Authorization": f"Bearer {self.token}",
+            "Content-Type": "application/json",
+        }
+        data = {
+            "api_key": api_key,
+            "base_url": base_url,
+            "model_mapping": json.loads(model_mapping) if model_mapping else None,
+            "config": json.loads(config) if config else None,
+        }
+        try:
+            response = requests.put(url, headers=headers, json=data)
+
+            if response.status_code == HTTPStatus.OK:
+                print(f"✅ Successfully updated {provider_name} API key!")
+                return True
+            else:
+                print(f"❌ Error updating provider key: {response.status_code}")
+                print(response.text)
+                return False
+        except Exception as e:
+            print(f"❌ Error updating provider key: {str(e)}")
+            return False
+
 
     def delete_provider_key(self, provider_name: str) -> bool:
         """Delete a provider key"""
@@ -233,7 +265,7 @@ def delete_provider_key(self, provider_name: str) -> bool:
             response = requests.delete(url, headers=headers)
 
             if response.status_code == HTTPStatus.OK:
-                print(f"✅ Successfully deleted {provider_name} API key!")
+                print(f"✅ Successfully deleted provider key {provider_name}!")
                 return True
             else:
                 print(f"❌ Error deleting provider key: {response.status_code}")
@@ -462,8 +494,9 @@ def main():
         print("8. Add Provider Key")
         print("9. List Provider Keys")
         print("10. Delete Provider Key")
-        print("11. Test Chat Completion")
-        print("12. List Models")
+        print("11. Update Provider Key")
+        print("12. Test Chat Completion")
+        print("13. List Models")
         print("0. Exit")
 
-        choice = input("\nEnter your choice (0-12): ")
+        choice = input("\nEnter your choice (0-13): ")
@@ -614,7 +647,7 @@ def main():
             key = getpass("Enter provider API key: ")
             base_url = input("Enter provider base URL (optional, press Enter to skip): ")
             config = input("Enter provider config in json string format (optional, press Enter to skip): ")
-            model_mapping = input("Enter model ampping config in json string format (optional, press Enter to skip): ")
+            model_mapping = input("Enter model mapping config in json string format (optional, press Enter to skip): ")
             forge.add_provider_key(provider, key, base_url=base_url, config=config, model_mapping=model_mapping)
 
         elif choice == "9":
@@ -624,10 +657,24 @@ def main():
             forge.list_provider_keys()
 
         elif choice == "10":
-            provider = input("Enter provider name to delete: ")
-            forge.delete_provider_key(provider)
+            if not forge.token:
+                token = input("Enter JWT token: ")
+                forge.token = token
+            provider_name = input("Enter provider name to delete: ")
+            forge.delete_provider_key(provider_name)
 
         elif choice == "11":
+            if not forge.token:
+                token = input("Enter JWT token: ")
+                forge.token = token
+            provider_name = input("Enter provider name to update: ")
+            api_key = getpass("Enter provider API key: ")
+            base_url = input("Enter provider base URL (optional, press Enter to skip): ")
+            config = input("Enter provider config in json string format (optional, press Enter to skip): ")
+            model_mapping = input("Enter model mapping config in json string format (optional, press Enter to skip): ")
+            forge.update_provider_key(provider_name, api_key, base_url=base_url, config=config, model_mapping=model_mapping)
+
+        elif choice == "12":
             model = input("Enter model ID: ")
             message = input("Enter message: ")
             api_key = input("Enter your Forge API key: ").strip()
@@ -642,7 +689,7 @@ def main():
                 continue
             forge.test_chat_completion(model, message, api_key)
 
-        elif choice == "12":
+        elif choice == "13":
             api_key = input(
                 "Enter your Forge API key (or press Enter to use stored key if available): "
             ).strip()
diff --git 
a/pyproject.toml b/pyproject.toml
index 0471c82..5075668 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
     "python-jose>=3.3.0",
     "passlib>=1.7.4",
     "python-multipart>=0.0.5",
-    "sqlalchemy>=2.0.0",
+    "sqlalchemy[asyncio]>=2.0.0",
     "alembic>=1.10.4",
     "aiohttp>=3.8.4",
     "cryptography>=40.0.0",
diff --git a/tests/cache/test_async_cache.py b/tests/cache/test_async_cache.py
index 115693e..ac32cd7 100755
--- a/tests/cache/test_async_cache.py
+++ b/tests/cache/test_async_cache.py
@@ -11,6 +11,8 @@
 from pathlib import Path
 from unittest.mock import MagicMock
 
+from sqlalchemy import delete
+
 from app.core.async_cache import (
     AsyncCache,
     async_cached,
@@ -188,7 +190,7 @@ async def test_model_list_async_cache():
-    ProviderService.cache_models(provider_name, cache_key, mock_models)
+    await ProviderService.cache_models(provider_name, cache_key, mock_models)
 
     # Test retrieving from cache
-    cached_models = ProviderService.get_cached_models(provider_name, cache_key)
+    cached_models = await ProviderService.get_cached_models(provider_name, cache_key)
     assert cached_models is not None, "Async model list cache get failed"
     assert len(cached_models) == EXPECTED_MODEL_COUNT, "Model list length mismatch"
     assert cached_models[FIRST_MODEL_INDEX]["id"] == "gpt-4", "Model ID mismatch"
@@ -197,7 +199,7 @@ async def test_model_list_async_cache():
     # Test cache invalidation
     ProviderService._models_cache = {}
     ProviderService._models_cache_expiry = {}
-    cached_models = ProviderService.get_cached_models(provider_name, cache_key)
+    cached_models = await ProviderService.get_cached_models(provider_name, cache_key)
     assert cached_models is None, "Async model list cache invalidation failed"
 
     print("✅ Async model list cache test passed")
@@ -354,14 +356,14 @@ async def test_async_cache_invalidation():
     model_cache_key = "default"
 
     # Set model list in cache
-    ProviderService.cache_models(provider_name, model_cache_key, mock_models)
-    cached_models = ProviderService.get_cached_models(provider_name, model_cache_key)
+    await ProviderService.cache_models(provider_name, model_cache_key, mock_models)
+    cached_models = await ProviderService.get_cached_models(provider_name, model_cache_key)
     assert cached_models is not None, "Model list cache set failed"
 
     # Invalidate model list cache
     ProviderService._models_cache = {}
     ProviderService._models_cache_expiry = {}
-    cached_models = ProviderService.get_cached_models(provider_name, model_cache_key)
+    cached_models = await ProviderService.get_cached_models(provider_name, model_cache_key)
     assert cached_models is None, "Model list cache invalidation failed"
 
     # Test 4: Provider service instance cache invalidation
@@ -511,55 +513,57 @@ async def test_async_cache_warming():
 
     from sqlalchemy import create_engine
     from sqlalchemy.orm import sessionmaker
+    from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
+
     from app.models.base import Base
 
     # Use SQLite in-memory database for testing
-    engine = create_engine("sqlite:///:memory:")
+    engine = create_async_engine("sqlite+aiosqlite:///:memory:")  # async driver required
 
     # Create all tables
-    Base.metadata.create_all(bind=engine)
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
 
-    testing = sessionmaker(autocommit=False, autoflush=False, bind=engine)
-    db = testing()
-
-    try:
-        # Create test user and API key
-        user = User(
-            email="test@example.com",
-            username="testuser",
-            is_active=True,
-            hashed_password="dummy_hash",
-        )
-        db.add(user)
-        db.commit()
-
-        # Create a test API key
-        test_api_key = "test_key_123"
-        encrypted_key = encrypt_api_key(test_api_key)
-
-        provider_key = ProviderKey(
-            user_id=user.id,
-            provider_name="test_provider",
-            encrypted_api_key=encrypted_key,
-        )
-        db.add(provider_key)
-        db.commit()
-
-        # Warm the cache
-        await warm_cache_async(db)
-
-        # Verify user is cached with the correct API key
-        cached_user = await get_cached_user_async(test_api_key)
-        assert cached_user is not None
-        assert cached_user.id == user.id
-
-    finally:
-        # Clean up
-        db.query(ProviderKey).delete()
-        db.query(User).delete()
-        db.commit()
-        db.close()
-        # Drop all tables
-        Base.metadata.drop_all(bind=engine)
+    testing = async_sessionmaker(autocommit=False, autoflush=False, bind=engine)
+    async with testing() as db:
+        try:
+            # Create test user and API key
+            user = User(
+                email="test@example.com",
+                username="testuser",
+                is_active=True,
+                hashed_password="dummy_hash",
+            )
+            db.add(user)
+            await db.commit()
+
+            # Create a test API key
+            test_api_key = "test_key_123"
+            encrypted_key = encrypt_api_key(test_api_key)
+
+            provider_key = ProviderKey(
+                user_id=user.id,
+                provider_name="test_provider",
+                encrypted_api_key=encrypted_key,
+            )
+            db.add(provider_key)
+            await db.commit()
+
+            # Warm the cache
+            await warm_cache_async(db)
+
+            # Verify user is cached with the correct API key
+            cached_user = await get_cached_user_async(test_api_key)
+            assert cached_user is not None
+            assert cached_user.id == user.id
+
+        finally:
+            # Clean up
+            await db.execute(delete(ProviderKey))
+            await db.execute(delete(User))
+            await db.commit()
+            # Drop all tables
+            async with engine.begin() as conn:
+                await conn.run_sync(Base.metadata.drop_all)
 
     return True
diff --git a/tests/cache/test_sync_cache.py b/tests/cache/test_sync_cache.py
deleted file mode 100755
index a9a696d..0000000
--- a/tests/cache/test_sync_cache.py
+++ /dev/null
@@ -1,374 +0,0 @@
-#!/usr/bin/env python3
-"""
-Test script to verify all types of caching in the system.
-Tests user cache, provider service cache, provider keys cache, and model list cache.

-""" - -import os -import sys -import time -from pathlib import Path -from unittest.mock import MagicMock - -from app.core.cache import ( - Cache, - invalidate_provider_models_cache, - invalidate_user_cache_by_id, - monitor_cache_performance, - provider_service_cache, - user_cache, - warm_cache, -) -from app.models.user import User -from app.services.provider_service import ProviderService - -# Add constants at the top of the file, after imports -EXPECTED_MODEL_COUNT = 2 # Expected number of models in test data -FIRST_MODEL_INDEX = 0 # Index of first model in test data -SECOND_MODEL_INDEX = 1 # Index of second model in test data - -# Add the project root to the Python path -script_dir = Path(__file__).resolve().parent.parent.parent -sys.path.insert(0, str(script_dir)) - -# Change to the project root directory -os.chdir(script_dir) - -# Set environment variables before importing cache modules -os.environ["FORCE_MEMORY_CACHE"] = "true" -os.environ["DEBUG_CACHE"] = "true" - - -# Clear caches before each test -def setup_function(function): - user_cache.clear() - provider_service_cache.clear() - ProviderService._models_cache = {} - ProviderService._models_cache_expiry = {} - ProviderService._models_l1_cache = {} - - -def test_basic_cache_operations(): - """Test basic cache operations""" - print("\n🔍 TESTING BASIC CACHE OPERATIONS") - print("================================") - - # Create a new cache instance - cache = Cache(ttl_seconds=5) - - # Test set and get - cache.set("test_key", "test_value") - value = cache.get("test_key") - assert value == "test_value", "Cache get/set failed" - - # Test TTL - cache.set("expiring_key", "expiring_value", ttl=1) - time.sleep(1.1) # Wait for TTL to expire - value = cache.get("expiring_key") - assert value is None, "TTL expiration failed" - - # Test delete - cache.set("delete_key", "delete_value") - cache.delete("delete_key") - value = cache.get("delete_key") - assert value is None, "Cache delete failed" - - # Test clear - cache.set("clear_key", "clear_value") - cache.clear() - value = cache.get("clear_key") - assert value is None, "Cache clear failed" - - print("✅ Basic cache operations test passed") - - -def test_user_cache(): - """Test user caching functionality""" - print("\n🔍 TESTING USER CACHE") - print("====================") - - # Create mock user - mock_user = User( - id=1, email="test@example.com", username="testuser", is_active=True - ) - - # Test caching user - api_key = "test_api_key_123" - user_cache.set(f"user:{api_key}", mock_user) - - # Test retrieving from cache - cached_user = user_cache.get(f"user:{api_key}") - assert cached_user is not None, "User cache get failed" - assert cached_user.id == mock_user.id, "Cached user ID mismatch" - assert cached_user.email == mock_user.email, "Cached user email mismatch" - - # Test cache invalidation - user_cache.delete(f"user:{api_key}") - cached_user = user_cache.get(f"user:{api_key}") - assert cached_user is None, "User cache invalidation failed" - - print("✅ User cache test passed") - - -def test_provider_keys_cache(): - """Test provider keys caching functionality""" - print("\n🔍 TESTING PROVIDER KEYS CACHE") - print("============================") - - # Create mock provider keys - mock_keys = { - "openai": { - "api_key": "sk_test_123", - "base_url": "https://api.openai.com/v1", - "model_mapping": {"gpt-4": "gpt-4-turbo"}, - }, - "anthropic": { - "api_key": "sk-ant-test-123", - "base_url": "https://api.anthropic.com/v1", - "model_mapping": {}, - }, - } - - # Test caching provider keys - user_id = 1 
- cache_key = f"provider_keys:{user_id}" - provider_service_cache.set(cache_key, mock_keys, ttl=3600) - - # Test retrieving from cache - cached_keys = provider_service_cache.get(cache_key) - assert cached_keys is not None, "Provider keys cache get failed" - assert "openai" in cached_keys, "OpenAI provider key missing from cache" - assert "anthropic" in cached_keys, "Anthropic provider key missing from cache" - assert ( - cached_keys["openai"]["model_mapping"]["gpt-4"] == "gpt-4-turbo" - ), "Model mapping mismatch" - - # Test cache invalidation - provider_service_cache.delete(cache_key) - cached_keys = provider_service_cache.get(cache_key) - assert cached_keys is None, "Provider keys cache invalidation failed" - - print("✅ Provider keys cache test passed") - - -def test_model_list_cache(): - """Test model list caching functionality""" - print("\n🔍 TESTING MODEL LIST CACHE") - print("=========================") - - # Create mock model list - mock_models = [ - { - "id": "gpt-4", - "display_name": "GPT-4", - "object": "model", - "owned_by": "openai", - }, - { - "id": "claude-3", - "display_name": "Claude 3", - "object": "model", - "owned_by": "anthropic", - }, - ] - - # Test caching models - provider_name = "openai" - cache_key = "default" - ProviderService.cache_models(provider_name, cache_key, mock_models) - - # Test retrieving from cache - cached_models = ProviderService.get_cached_models(provider_name, cache_key) - assert cached_models is not None, "Model list cache get failed" - assert len(cached_models) == EXPECTED_MODEL_COUNT, "Model list length mismatch" - assert cached_models[FIRST_MODEL_INDEX]["id"] == "gpt-4", "Model ID mismatch" - assert cached_models[SECOND_MODEL_INDEX]["id"] == "claude-3", "Model ID mismatch" - - # Test cache invalidation - invalidate_provider_models_cache(provider_name) - cached_models = ProviderService.get_cached_models(provider_name, cache_key) - assert cached_models is None, "Model list cache invalidation failed" - - print("✅ Model list cache test passed") - - -def test_cache_invalidation(): - """Test cache invalidation scenarios""" - print("\n🔍 TESTING CACHE INVALIDATION") - print("===========================") - - # Test 1: User cache invalidation - print("\n🔄 Test 1: User cache invalidation") - mock_user = User( - id=1, email="test@example.com", username="testuser", is_active=True - ) - api_key = "test_api_key_123" - - # Set user in cache - user_cache.set(f"user:{api_key}", mock_user) - assert user_cache.get(f"user:{api_key}") is not None, "User cache set failed" - - # Invalidate user cache - user_cache.delete(f"user:{api_key}") - assert user_cache.get(f"user:{api_key}") is None, "User cache invalidation failed" - - # Test 2: Provider keys cache invalidation - print("\n🔄 Test 2: Provider keys cache invalidation") - mock_keys = { - "openai": { - "api_key": "sk_test_123", - "base_url": "https://api.openai.com/v1", - "model_mapping": {"gpt-4": "gpt-4-turbo"}, - } - } - user_id = 1 - cache_key = f"provider_keys:{user_id}" - - # Set provider keys in cache - provider_service_cache.set(cache_key, mock_keys, ttl=3600) - assert ( - provider_service_cache.get(cache_key) is not None - ), "Provider keys cache set failed" - - # Invalidate provider keys cache - provider_service_cache.delete(cache_key) - assert ( - provider_service_cache.get(cache_key) is None - ), "Provider keys cache invalidation failed" - - # Test 3: Model list cache invalidation - print("\n🔄 Test 3: Model list cache invalidation") - mock_models = [ - { - "id": "gpt-4", - "display_name": "GPT-4", - 
"object": "model", - "owned_by": "openai", - } - ] - provider_name = "openai" - model_cache_key = "default" - - # Set model list in cache - ProviderService.cache_models(provider_name, model_cache_key, mock_models) - assert ( - ProviderService.get_cached_models(provider_name, model_cache_key) is not None - ), "Model list cache set failed" - - # Invalidate model list cache - invalidate_provider_models_cache(provider_name) - assert ( - ProviderService.get_cached_models(provider_name, model_cache_key) is None - ), "Model list cache invalidation failed" - - -def test_cache_invalidation_by_id(): - """Test cache invalidation by user ID""" - print("\nTesting cache invalidation by user ID...") - - # Create test user - user = User( - id=1, - email="test@example.com", - username="testuser", - is_active=True, - hashed_password="dummy_hash", - ) - - # Create test API keys - api_key1 = "test_key_1" - api_key2 = "test_key_2" - - # Cache user with multiple API keys - user_cache.set(f"user:{api_key1}", user) - user_cache.set(f"user:{api_key2}", user) - - # Verify user is cached - cached_user1 = user_cache.get(f"user:{api_key1}") - cached_user2 = user_cache.get(f"user:{api_key2}") - assert cached_user1 is not None - assert cached_user2 is not None - assert cached_user1.id == user.id - assert cached_user2.id == user.id - - # Invalidate all cache entries for this user - invalidate_user_cache_by_id(user.id) - - # Verify cache is cleared - assert user_cache.get(f"user:{api_key1}") is None - assert user_cache.get(f"user:{api_key2}") is None - - -def test_provider_models_cache_invalidation(): - """Test provider models cache invalidation""" - print("\nTesting provider models cache invalidation...") - - # Set up test data - provider_name = "test_provider" - models = [{"id": "model1"}, {"id": "model2"}] - cache_key = "default" - - # Cache models using the public API - ProviderService.cache_models(provider_name, cache_key, models) - - # Verify models are cached - cached_models = ProviderService.get_cached_models(provider_name, cache_key) - assert cached_models is not None - assert len(cached_models) == 2 - assert cached_models[0]["id"] == "model1" - assert cached_models[1]["id"] == "model2" - - # Invalidate cache - invalidate_provider_models_cache(provider_name) - - # Verify cache is cleared - cached_models = ProviderService.get_cached_models(provider_name, cache_key) - assert cached_models is None, "Provider models cache invalidation failed" - - -def test_cache_stats_and_monitoring(): - """Test cache statistics and monitoring""" - print("\n🔍 TESTING CACHE STATS AND MONITORING") - print("===================================") - - # Test basic cache operations to generate stats - cache = Cache(ttl_seconds=5) - cache.set("test_key", "test_value") - cache.get("test_key") # Hit - cache.get("nonexistent") # Miss - - # Test stats - stats = cache.stats() - assert stats["hits"] == 1, "Cache hit count mismatch" - assert stats["misses"] == 1, "Cache miss count mismatch" - assert stats["total"] == 2, "Cache total count mismatch" - assert stats["hit_rate"] == 0.5, "Cache hit rate mismatch" - assert stats["entries"] == 1, "Cache entries count mismatch" - - # Test monitoring - monitoring = monitor_cache_performance() - assert "stats" in monitoring, "Cache stats missing" - assert "overall_hit_rate" in monitoring, "Overall hit rate missing" - assert "issues" in monitoring, "Issues list missing" - - print("✅ Cache stats and monitoring test passed") - - -async def test_cache_warming(): - """Test cache warming functionality""" - 
print("\n🔍 TESTING CACHE WARMING") - print("=======================") - - # Mock database session - db = MagicMock() - - # Test cache warming - await warm_cache(db) - - # Verify cache is populated - assert user_cache.stats()["entries"] > 0, "User cache not warmed" - assert ( - provider_service_cache.stats()["entries"] > 0 - ), "Provider service cache not warmed" - - print("✅ Cache warming test passed") diff --git a/tests/mock_testing/add_mock_provider.py b/tests/mock_testing/add_mock_provider.py index 6dc6c8f..795ac91 100755 --- a/tests/mock_testing/add_mock_provider.py +++ b/tests/mock_testing/add_mock_provider.py @@ -4,6 +4,7 @@ This allows users to test the Forge middleware without needing real API keys. """ +import asyncio import argparse import json import os @@ -13,10 +14,11 @@ from dotenv import load_dotenv from app.core.cache import provider_service_cache -from app.core.database import SessionLocal +from app.core.database import get_db_session from app.core.security import encrypt_api_key from app.models.provider_key import ProviderKey from app.models.user import User +from sqlalchemy import select # Add the project root to the Python path script_dir = Path(__file__).resolve().parent.parent.parent @@ -29,91 +31,86 @@ os.chdir(script_dir) -def setup_mock_provider(username: str, force: bool = False): +async def setup_mock_provider(username: str, force: bool = False): """Add a mock provider key to the specified user account""" # Create a database session - db = SessionLocal() + async with get_db_session() as db: + try: + # Find the user + result = await db.execute(select(User).filter(User.username == username)) + user = result.scalar_one_or_none() + if not user: + print(f"❌ User '{username}' not found. Please provide a valid username.") + return False + + # Check if the mock provider already exists for this user + result = await db.execute(select(ProviderKey).filter(ProviderKey.user_id == user.id, ProviderKey.provider_name == "mock")) + existing_provider = result.scalar_one_or_none() + + if existing_provider and not force: + print(f"⚠️ Mock provider already exists for user '{username}'.") + print("Use --force to replace it.") + return False + + # If force is set and provider exists, delete the existing one + if existing_provider and force: + db.delete(existing_provider) + db.commit() + print(f"🗑️ Deleted existing mock provider for user '{username}'.") + + # Create a mock API key - it doesn't need to be secure as it's not used + mock_api_key = "mock-api-key-for-testing-purposes" + encrypted_key = encrypt_api_key(mock_api_key) + + # Create model mappings for common models to their mock equivalents + model_mapping = { + "mock-only-gpt-3.5-turbo": "mock-gpt-3.5-turbo", + "mock-only-gpt-4": "mock-gpt-4", + "mock-only-gpt-4o": "mock-gpt-4o", + "mock-only-claude-3-opus": "mock-claude-3-opus", + "mock-only-claude-3-sonnet": "mock-claude-3-sonnet", + "mock-only-claude-3-haiku": "mock-claude-3-haiku", + } + + # Create the provider key + provider_key = ProviderKey( + user_id=user.id, + provider_name="mock", + encrypted_api_key=encrypted_key, + model_mapping=json.dumps( + model_mapping + ), # Use json.dumps for proper storage + ) + + db.add(provider_key) + db.commit() - try: - # Find the user - user = db.query(User).filter(User.username == username).first() - if not user: - print(f"❌ User '{username}' not found. 
Please provide a valid username.") + # Invalidate provider key cache for this user to force refresh + provider_service_cache.delete(f"provider_keys:{user.id}") + print(f"✅ Invalidated provider key cache for user '{username}'") + + print(f"✅ Successfully added mock provider for user '{username}'.") + print( + f"🔑 Mock API Key: {mock_api_key} (not a real key, used for testing only)" + ) + print("") + print("You can now use the following models with this provider:") + for original, mock in model_mapping.items(): + print(f" - {original} -> {mock}") + print("") + print( + "Use these models with your Forge API Key to test the middleware without real API calls." + ) + + return True + + except Exception as e: + await db.rollback() + print(f"❌ Error setting up mock provider: {str(e)}") return False - # Check if the mock provider already exists for this user - existing_provider = ( - db.query(ProviderKey) - .filter(ProviderKey.user_id == user.id, ProviderKey.provider_name == "mock") - .first() - ) - - if existing_provider and not force: - print(f"⚠️ Mock provider already exists for user '{username}'.") - print("Use --force to replace it.") - return False - # If force is set and provider exists, delete the existing one - if existing_provider and force: - db.delete(existing_provider) - db.commit() - print(f"🗑️ Deleted existing mock provider for user '{username}'.") - - # Create a mock API key - it doesn't need to be secure as it's not used - mock_api_key = "mock-api-key-for-testing-purposes" - encrypted_key = encrypt_api_key(mock_api_key) - - # Create model mappings for common models to their mock equivalents - model_mapping = { - "mock-only-gpt-3.5-turbo": "mock-gpt-3.5-turbo", - "mock-only-gpt-4": "mock-gpt-4", - "mock-only-gpt-4o": "mock-gpt-4o", - "mock-only-claude-3-opus": "mock-claude-3-opus", - "mock-only-claude-3-sonnet": "mock-claude-3-sonnet", - "mock-only-claude-3-haiku": "mock-claude-3-haiku", - } - - # Create the provider key - provider_key = ProviderKey( - user_id=user.id, - provider_name="mock", - encrypted_api_key=encrypted_key, - model_mapping=json.dumps( - model_mapping - ), # Use json.dumps for proper storage - ) - - db.add(provider_key) - db.commit() - - # Invalidate provider key cache for this user to force refresh - provider_service_cache.delete(f"provider_keys:{user.id}") - print(f"✅ Invalidated provider key cache for user '{username}'") - - print(f"✅ Successfully added mock provider for user '{username}'.") - print( - f"🔑 Mock API Key: {mock_api_key} (not a real key, used for testing only)" - ) - print("") - print("You can now use the following models with this provider:") - for original, mock in model_mapping.items(): - print(f" - {original} -> {mock}") - print("") - print( - "Use these models with your Forge API Key to test the middleware without real API calls." - ) - - return True - - except Exception as e: - db.rollback() - print(f"❌ Error setting up mock provider: {str(e)}") - return False - finally: - db.close() - - -def main(): +async def main(): parser = argparse.ArgumentParser( description="Add a mock provider to a user account for testing" ) @@ -136,4 +133,4 @@ def main(): if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/tests/mock_testing/setup_test_user.py b/tests/mock_testing/setup_test_user.py index b90811a..defe1df 100644 --- a/tests/mock_testing/setup_test_user.py +++ b/tests/mock_testing/setup_test_user.py @@ -4,6 +4,7 @@ This script is used to prepare the environment for testing the mock provider. 
""" +import asyncio import json import os import sys @@ -11,9 +12,10 @@ from dotenv import load_dotenv from passlib.context import CryptContext +from sqlalchemy import select from app.core.cache import invalidate_user_cache, provider_service_cache -from app.core.database import SessionLocal +from app.core.database import get_db_session from app.core.security import encrypt_api_key from app.models.provider_key import ProviderKey from app.models.user import User @@ -37,130 +39,123 @@ MOCK_PROVIDER_API_KEY = "mock-api-key-for-testing-purposes" -def create_or_update_test_user(): +async def create_or_update_test_user(): """Create a test user with a known Forge API key or update existing user""" - db = SessionLocal() - - try: - # Try to find user by username first - user = db.query(User).filter(User.username == TEST_USERNAME).first() - - # If not found by username, try by email - if not user: - user = db.query(User).filter(User.email == TEST_EMAIL).first() - - # If user exists, update the forge API key - if user: - print(f"✅ Found existing user: {user.username} (email: {user.email})") - old_key = user.forge_api_key - user.forge_api_key = TEST_FORGE_API_KEY - db.commit() - db.refresh(user) - # Invalidate the user in cache to force refresh with new API key - invalidate_user_cache(old_key) - invalidate_user_cache(TEST_FORGE_API_KEY) - print( - f"✅ Invalidated user cache for API keys: {old_key} and {TEST_FORGE_API_KEY}" + async with get_db_session() as db: + try: + # Try to find user by username first + result = await db.execute(select(User).filter(User.username == TEST_USERNAME)) + user = result.scalar_one_or_none() + + # If not found by username, try by email + if not user: + result = await db.execute(select(User).filter(User.email == TEST_EMAIL)) + user = result.scalar_one_or_none() + + # If user exists, update the forge API key + if user: + print(f"✅ Found existing user: {user.username} (email: {user.email})") + old_key = user.forge_api_key + user.forge_api_key = TEST_FORGE_API_KEY + await db.commit() + await db.refresh(user) + # Invalidate the user in cache to force refresh with new API key + invalidate_user_cache(old_key) + invalidate_user_cache(TEST_FORGE_API_KEY) + print( + f"✅ Invalidated user cache for API keys: {old_key} and {TEST_FORGE_API_KEY}" + ) + print(f"🔄 Updated Forge API Key: {old_key} -> {user.forge_api_key}") + return user + + # Create new user if not exists + hashed_password = pwd_context.hash(TEST_PASSWORD) + user = User( + username=TEST_USERNAME, + email=TEST_EMAIL, + hashed_password=hashed_password, + forge_api_key=TEST_FORGE_API_KEY, + is_active=True, ) - print(f"🔄 Updated Forge API Key: {old_key} -> {user.forge_api_key}") + db.add(user) + await db.commit() + await db.refresh(user) + print(f"✅ Created test user '{TEST_USERNAME}'") + print(f"🔑 Forge API Key: {TEST_FORGE_API_KEY}") return user - # Create new user if not exists - hashed_password = pwd_context.hash(TEST_PASSWORD) - user = User( - username=TEST_USERNAME, - email=TEST_EMAIL, - hashed_password=hashed_password, - forge_api_key=TEST_FORGE_API_KEY, - is_active=True, - ) - db.add(user) - db.commit() - db.refresh(user) - print(f"✅ Created test user '{TEST_USERNAME}'") - print(f"🔑 Forge API Key: {TEST_FORGE_API_KEY}") - return user - - except Exception as e: - db.rollback() - print(f"❌ Error creating/updating test user: {str(e)}") - return None - finally: - db.close() - - -def add_mock_provider_to_user(user_id): - """Add a mock provider to the test user""" - db = SessionLocal() - - try: - # Check if the mock provider 
already exists for this user - existing_provider = ( - db.query(ProviderKey) - .filter(ProviderKey.user_id == user_id, ProviderKey.provider_name == "mock") - .first() - ) + except Exception as e: + await db.rollback() + print(f"❌ Error creating/updating test user: {str(e)}") + return None - if existing_provider: - print("✅ Mock provider already exists for the test user.") - return True - # Create a mock API key - encrypted_key = encrypt_api_key(MOCK_PROVIDER_API_KEY) - - # Create model mappings for common models to their mock equivalents - model_mapping = { - "mock-only-gpt-3.5-turbo": "mock-gpt-3.5-turbo", - "mock-only-gpt-4": "mock-gpt-4", - "mock-only-gpt-4o": "mock-gpt-4o", - "mock-only-claude-3-opus": "mock-claude-3-opus", - "mock-only-claude-3-sonnet": "mock-claude-3-sonnet", - "mock-only-claude-3-haiku": "mock-claude-3-haiku", - } - - # Create the provider key - provider_key = ProviderKey( - user_id=user_id, - provider_name="mock", - encrypted_api_key=encrypted_key, - model_mapping=json.dumps(model_mapping), - ) +async def add_mock_provider_to_user(user_id): + """Add a mock provider to the test user""" + async with get_db_session() as db: + try: + # Check if the mock provider already exists for this user + result = await db.execute(select(ProviderKey).filter(ProviderKey.user_id == user_id, ProviderKey.provider_name == "mock")) + existing_provider = result.scalar_one_or_none() + + if existing_provider: + print("✅ Mock provider already exists for the test user.") + return True + + # Create a mock API key + encrypted_key = encrypt_api_key(MOCK_PROVIDER_API_KEY) + + # Create model mappings for common models to their mock equivalents + model_mapping = { + "mock-only-gpt-3.5-turbo": "mock-gpt-3.5-turbo", + "mock-only-gpt-4": "mock-gpt-4", + "mock-only-gpt-4o": "mock-gpt-4o", + "mock-only-claude-3-opus": "mock-claude-3-opus", + "mock-only-claude-3-sonnet": "mock-claude-3-sonnet", + "mock-only-claude-3-haiku": "mock-claude-3-haiku", + } + + # Create the provider key + provider_key = ProviderKey( + user_id=user_id, + provider_name="mock", + encrypted_api_key=encrypted_key, + model_mapping=json.dumps(model_mapping), + ) - db.add(provider_key) - db.commit() + db.add(provider_key) + await db.commit() - # Invalidate provider key cache for this user to force refresh - provider_service_cache.delete(f"provider_keys:{user_id}") - print(f"✅ Invalidated provider key cache for user ID: {user_id}") + # Invalidate provider key cache for this user to force refresh + provider_service_cache.delete(f"provider_keys:{user_id}") + print(f"✅ Invalidated provider key cache for user ID: {user_id}") - print("✅ Successfully added mock provider for test user.") - print(f"🔑 Mock API Key: {MOCK_PROVIDER_API_KEY} (used for testing only)") - print("") - print("You can now use the following models with this provider:") - for original, mock in model_mapping.items(): - print(f" - {original} -> {mock}") + print("✅ Successfully added mock provider for test user.") + print(f"🔑 Mock API Key: {MOCK_PROVIDER_API_KEY} (used for testing only)") + print("") + print("You can now use the following models with this provider:") + for original, mock in model_mapping.items(): + print(f" - {original} -> {mock}") - return True + return True - except Exception as e: - db.rollback() - print(f"❌ Error setting up mock provider: {str(e)}") - return False - finally: - db.close() + except Exception as e: + await db.rollback() + print(f"❌ Error setting up mock provider: {str(e)}") + return False -def main(): +async def main(): """Set up a test user 
with a mock provider""" print("🔄 Setting up test user with mock provider...") # Create or update test user - user = create_or_update_test_user() + user = await create_or_update_test_user() if not user: sys.exit(1) # Add mock provider to user - if add_mock_provider_to_user(user.id): + if await add_mock_provider_to_user(user.id): print("") print("✅ Setup complete!") print("📝 To test the mock provider, run:") @@ -173,4 +168,4 @@ def main(): if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/tests/unit_tests/test_provider_service.py b/tests/unit_tests/test_provider_service.py index 5ee8c58..51f8c6d 100644 --- a/tests/unit_tests/test_provider_service.py +++ b/tests/unit_tests/test_provider_service.py @@ -22,13 +22,13 @@ from app.models.user import User from app.services.provider_service import ProviderService from app.services.providers.adapter_factory import ProviderAdapterFactory -from app.core.cache import provider_service_cache, user_cache +from app.core.async_cache import async_provider_service_cache, async_user_cache class TestProviderService(TestCase): """Test the provider service""" - def setUp(self): + async def asyncSetUp(self): # Reset the adapters cache ProviderService._adapters_cache = {} @@ -38,63 +38,55 @@ def setUp(self): self.provider_key_openai.provider_name = "openai" self.provider_key_openai.encrypted_api_key = "encrypted_openai_key" self.provider_key_openai.base_url = None - self.provider_key_openai.model_mapping = json.dumps({"custom-gpt": "gpt-4"}) + self.provider_key_openai.model_mapping = {"custom-gpt": "gpt-4"} self.provider_key_anthropic = MagicMock(spec=ProviderKey) self.provider_key_anthropic.provider_name = "anthropic" self.provider_key_anthropic.encrypted_api_key = "encrypted_anthropic_key" self.provider_key_anthropic.base_url = None - self.provider_key_anthropic.model_mapping = "{}" + self.provider_key_anthropic.model_mapping = {} self.provider_key_google = MagicMock(spec=ProviderKey) self.provider_key_google.provider_name = "gemini" self.provider_key_google.encrypted_api_key = "encrypted_gemini_key" self.provider_key_google.base_url = None - self.provider_key_google.model_mapping = json.dumps( - {"test-gemini": "models/gemini-2.0-flash"} - ) + self.provider_key_google.model_mapping = {"test-gemini": "models/gemini-2.0-flash"} self.provider_key_xai = MagicMock(spec=ProviderKey) self.provider_key_xai.provider_name = "xai" self.provider_key_xai.encrypted_api_key = "encrypted_xai_key" self.provider_key_xai.base_url = None - self.provider_key_xai.model_mapping = json.dumps({"test-xai": "grok-2-1212"}) + self.provider_key_xai.model_mapping = {"test-xai": "grok-2-1212"} self.provider_key_fireworks = MagicMock(spec=ProviderKey) self.provider_key_fireworks.provider_name = "fireworks" self.provider_key_fireworks.encrypted_api_key = "encrypted_fireworks_key" self.provider_key_fireworks.base_url = None - self.provider_key_fireworks.model_mapping = json.dumps( - {"test-fireworks": "accounts/fireworks/models/code-llama-7b"} - ) + self.provider_key_fireworks.model_mapping = {"test-fireworks": "accounts/fireworks/models/code-llama-7b"} self.provider_key_openrouter = MagicMock(spec=ProviderKey) self.provider_key_openrouter.provider_name = "openrouter" self.provider_key_openrouter.encrypted_api_key = "encrypted_openrouter_key" self.provider_key_openrouter.base_url = None - self.provider_key_openrouter.model_mapping = json.dumps( - {"test-openrouter": "gpt-4o"} - ) + self.provider_key_openrouter.model_mapping = {"test-openrouter": "gpt-4o"} 
self.provider_key_together = MagicMock(spec=ProviderKey) self.provider_key_together.provider_name = "together" self.provider_key_together.encrypted_api_key = "encrypted_together_key" self.provider_key_together.base_url = None - self.provider_key_together.model_mapping = json.dumps( - {"test-together": "UAE-Large-V1"} - ) + self.provider_key_together.model_mapping = {"test-together": "UAE-Large-V1"} self.provider_key_azure = MagicMock(spec=ProviderKey) self.provider_key_azure.provider_name = "azure" self.provider_key_azure.encrypted_api_key = "encrypted_azure_key" self.provider_key_azure.base_url = "https://test-azure.openai.com" - self.provider_key_azure.model_mapping = json.dumps({"test-azure": "gpt-4o"}) + self.provider_key_azure.model_mapping = {"test-azure": "gpt-4o"} self.provider_key_bedrock = MagicMock(spec=ProviderKey) self.provider_key_bedrock.provider_name = "bedrock" self.provider_key_bedrock.encrypted_api_key = "encrypted_bedrock_key" self.provider_key_bedrock.base_url = None - self.provider_key_bedrock.model_mapping = json.dumps({"test-bedrock": "claude-3-5-sonnet-20240620-v1:0"}) + self.provider_key_bedrock.model_mapping = {"test-bedrock": "claude-3-5-sonnet-20240620-v1:0"} self.user.provider_keys = [ self.provider_key_openai, @@ -108,12 +100,12 @@ def setUp(self): self.provider_key_bedrock, ] - # Mock DB - self.db = MagicMock() + # Mock AsyncSession DB + self.db = AsyncMock() # Clear caches - provider_service_cache.clear() - user_cache.clear() + await async_provider_service_cache.clear() + await async_user_cache.clear() # Create the service with patched decrypt_api_key to avoid actual decryption with patch("app.services.provider_service.decrypt_api_key") as mock_decrypt: @@ -141,8 +133,11 @@ def setUp(self): # Mock user.id for the new constructor signature self.user.id = 1 - # Mock the database query for provider keys - self.db.query.return_value.filter.return_value.all.return_value = [ + # Mock the async database execute() pattern for provider keys + # Create mock result object + mock_result = MagicMock() # Result object should be sync, not AsyncMock + mock_scalars = MagicMock() # Don't use AsyncMock for scalars object + mock_scalars.all.return_value = [ self.provider_key_openai, self.provider_key_anthropic, self.provider_key_google, @@ -153,11 +148,13 @@ def setUp(self): self.provider_key_azure, self.provider_key_bedrock, ] + mock_result.scalars.return_value = mock_scalars # scalars() returns sync object + self.db.execute = AsyncMock(return_value=mock_result) # Only execute() is async self.service = ProviderService(self.user.id, self.db) # Pre-load the keys for testing - self.service._load_provider_keys() + await self.service._load_provider_keys_async() async def test_load_provider_keys(self): """Test loading provider keys""" diff --git a/tests/unit_tests/test_provider_service_images.py b/tests/unit_tests/test_provider_service_images.py index 1451585..4792d1e 100644 --- a/tests/unit_tests/test_provider_service_images.py +++ b/tests/unit_tests/test_provider_service_images.py @@ -23,40 +23,38 @@ from app.models.user import User from app.services.provider_service import ProviderService from app.services.providers.adapter_factory import ProviderAdapterFactory -from app.core.cache import provider_service_cache, user_cache +from app.core.async_cache import async_provider_service_cache, async_user_cache class TestProviderServiceImages(TestCase): """Test cases for ProviderService images endpoints""" - def setUp(self): + async def asyncSetUp(self): # Mock user with provider keys 
self.user = MagicMock(spec=User) self.provider_key_openai = MagicMock(spec=ProviderKey) self.provider_key_openai.provider_name = "openai" self.provider_key_openai.encrypted_api_key = "encrypted_openai_key" self.provider_key_openai.base_url = None - self.provider_key_openai.model_mapping = json.dumps({"dall-e-2": "dall-e-2"}) + self.provider_key_openai.model_mapping = {"dall-e-2": "dall-e-2"} self.provider_key_anthropic = MagicMock(spec=ProviderKey) self.provider_key_anthropic.provider_name = "anthropic" self.provider_key_anthropic.encrypted_api_key = "encrypted_anthropic_key" self.provider_key_anthropic.base_url = None - self.provider_key_anthropic.model_mapping = json.dumps( - {"custom-anthropic": "claude-3-opus", "claude-3-opus": "claude-3-opus"} - ) + self.provider_key_anthropic.model_mapping = {"custom-anthropic": "claude-3-opus", "claude-3-opus": "claude-3-opus"} self.user.provider_keys = [ self.provider_key_openai, self.provider_key_anthropic, ] - # Mock DB - self.db = MagicMock() + # Mock AsyncSession DB + self.db = AsyncMock() # Clear caches - provider_service_cache.clear() - user_cache.clear() + await async_provider_service_cache.clear() + await async_user_cache.clear() # Remove ProviderService creation from setUp # It will be created in each test after patching @@ -85,15 +83,19 @@ async def test_process_request_images_generations_routing( # Create the service with the NEW constructor signature (user.id) self.user.id = 1 - # Mock the database query that the new loading mechanism uses - self.db.query.return_value.filter.return_value.all.return_value = [ + # Mock the async database execute() pattern for provider keys + mock_result = MagicMock() # Result object should be sync, not AsyncMock + mock_scalars = MagicMock() # Don't use AsyncMock for scalars object + mock_scalars.all.return_value = [ self.provider_key_openai, self.provider_key_anthropic, ] + mock_result.scalars.return_value = mock_scalars # scalars() returns sync object + self.db.execute = AsyncMock(return_value=mock_result) # Only execute() is async service = ProviderService(self.user.id, self.db) # Let the service load keys properly through the new mechanism - service._load_provider_keys() + await service._load_provider_keys_async() # mock openai image generation response # no need to mock the response for anthropic @@ -151,15 +153,19 @@ async def test_process_request_images_edits_routing( # Create the service with the NEW constructor signature (user.id) self.user.id = 1 - # Mock the database query that the new loading mechanism uses - self.db.query.return_value.filter.return_value.all.return_value = [ + # Mock the async database execute() pattern for provider keys + mock_result = MagicMock() # Result object should be sync, not AsyncMock + mock_scalars = MagicMock() # Don't use AsyncMock for scalars object + mock_scalars.all.return_value = [ self.provider_key_openai, self.provider_key_anthropic, ] + mock_result.scalars.return_value = mock_scalars # scalars() returns sync object + self.db.execute = AsyncMock(return_value=mock_result) # Only execute() is async service = ProviderService(self.user.id, self.db) # Let the service load keys properly through the new mechanism - service._load_provider_keys() + await service._load_provider_keys_async() # mock openai image edits response # no need to mock the response for anthropic diff --git a/tools/diagnostics/fix_model_mapping.py b/tools/diagnostics/fix_model_mapping.py index d883a30..97e5a76 100755 --- a/tools/diagnostics/fix_model_mapping.py +++ 
b/tools/diagnostics/fix_model_mapping.py @@ -4,11 +4,12 @@ Specifically for fixing the gpt-4o to mock-gpt-4o mapping issue. """ +import asyncio import os import sys from pathlib import Path -from app.core.database import get_db +from app.core.database import get_async_db # Add the project root to the Python path script_dir = Path(__file__).resolve().parent.parent.parent @@ -18,13 +19,14 @@ os.chdir(script_dir) -def fix_model_mappings(): +async def fix_model_mappings(): """Fix model mappings by clearing caches""" print("\n🔧 FIXING MODEL MAPPINGS") print("======================") # Get DB session - next(get_db()) + async with get_async_db() as db: + pass # Clear all caches to ensure changes take effect print("🔄 Invalidating provider service cache for all users") @@ -38,9 +40,9 @@ def fix_model_mappings(): return True -def main(): +async def main(): """Main entry point""" - if fix_model_mappings(): + if await fix_model_mappings(): print( "\n✅ Model mappings have been fixed. Use check_model_mappings.py to verify." ) @@ -51,4 +53,4 @@ def main(): if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/uv.lock b/uv.lock index a45c8d3..32eea87 100644 --- a/uv.lock +++ b/uv.lock @@ -532,7 +532,7 @@ dependencies = [ { name = "python-multipart" }, { name = "redis" }, { name = "requests" }, - { name = "sqlalchemy" }, + { name = "sqlalchemy", extra = ["asyncio"] }, { name = "svix" }, { name = "uvicorn" }, ] @@ -580,7 +580,7 @@ requires-dist = [ { name = "redis", specifier = ">=4.6.0" }, { name = "requests", specifier = ">=2.28.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.2.0" }, - { name = "sqlalchemy", specifier = ">=2.0.0" }, + { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.0" }, { name = "svix", specifier = ">=1.13.0" }, { name = "uvicorn", specifier = ">=0.22.0" }, ] @@ -1394,6 +1394,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/7c/5fc8e802e7506fe8b55a03a2e1dab156eae205c91bee46305755e086d2e2/sqlalchemy-2.0.40-py3-none-any.whl", hash = "sha256:32587e2e1e359276957e6fe5dad089758bc042a971a8a09ae8ecf7a8fe23d07a", size = 1903894 }, ] +[package.optional-dependencies] +asyncio = [ + { name = "greenlet" }, +] + [[package]] name = "starlette" version = "0.46.1" From bf398bc6d35c2b0a6e8dd657a4ca1efca16a51f7 Mon Sep 17 00:00:00 2001 From: charles Date: Wed, 9 Jul 2025 21:47:41 -0700 Subject: [PATCH 08/10] Add vertex ai support --- app/api/schemas/provider_key.py | 42 +-- app/exceptions/exceptions.py | 2 +- app/main.py | 2 +- app/services/providers/adapter_factory.py | 4 + app/services/providers/anthropic_adapter.py | 69 +++-- app/services/providers/base.py | 41 ++- app/services/providers/vertex_adapter.py | 216 ++++++++++++++ pyproject.toml | 2 + uv.lock | 308 ++++++++++++++++++++ 9 files changed, 613 insertions(+), 73 deletions(-) create mode 100644 app/services/providers/vertex_adapter.py diff --git a/app/api/schemas/provider_key.py b/app/api/schemas/provider_key.py index 17dbc57..1a44588 100644 --- a/app/api/schemas/provider_key.py +++ b/app/api/schemas/provider_key.py @@ -9,46 +9,6 @@ logger = get_logger(name="provider_key") -# Constants for API key masking -API_KEY_MASK_PREFIX_LENGTH = 2 -API_KEY_MASK_SUFFIX_LENGTH = 4 -# Minimum length to apply the full prefix + suffix mask (e.g., pr******fix) -# This means if length is > (PREFIX + SUFFIX), we can apply the full rule. 
-MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC = ( - API_KEY_MASK_PREFIX_LENGTH + API_KEY_MASK_SUFFIX_LENGTH -) - - -# Helper function for masking API keys -def _mask_api_key_value(value: str | None) -> str | None: - if not value: - return None - - length = len(value) - - if length == 0: - return "" - - # If key is too short for any meaningful prefix/suffix masking - if length <= API_KEY_MASK_PREFIX_LENGTH: - return "*" * length - - # If key is long enough for prefix, but not for prefix + suffix - # e.g., length is 3, 4, 5, 6. For these, show prefix and mask the rest. - if length <= MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC: - return value[:API_KEY_MASK_PREFIX_LENGTH] + "*" * ( - length - API_KEY_MASK_PREFIX_LENGTH - ) - - # If key is long enough for the full prefix + ... + suffix mask - # number of asterisks = length - prefix_length - suffix_length - num_asterisks = length - API_KEY_MASK_PREFIX_LENGTH - API_KEY_MASK_SUFFIX_LENGTH - return ( - value[:API_KEY_MASK_PREFIX_LENGTH] - + "*" * num_asterisks - + value[-API_KEY_MASK_SUFFIX_LENGTH:] - ) - class ProviderKeyBase(BaseModel): provider_name: str = Field(..., min_length=1) @@ -109,7 +69,7 @@ def api_key(self) -> str | None: api_key, _ = provider_adapter_cls.deserialize_api_key_config( decrypted_value ) - return _mask_api_key_value(api_key) + return provider_adapter_cls.mask_api_key(api_key) except Exception as e: logger.error( f"Error deserializing API key for provider {self.provider_name}: {e}" diff --git a/app/exceptions/exceptions.py b/app/exceptions/exceptions.py index 112527f..0851e9f 100644 --- a/app/exceptions/exceptions.py +++ b/app/exceptions/exceptions.py @@ -80,4 +80,4 @@ def __init__(self, error: Exception): class InvalidForgeKeyException(BaseInvalidForgeKeyException): """Exception raised when a Forge key is invalid.""" def __init__(self, error: Exception): - super().__init__(error) \ No newline at end of file + super().__init__(error) diff --git a/app/main.py b/app/main.py index 85364a6..9556c0c 100644 --- a/app/main.py +++ b/app/main.py @@ -94,7 +94,7 @@ async def provider_authentication_exception_handler(request: Request, exc: Provi async def invalid_provider_exception_handler(request: Request, exc: InvalidProviderException): return HTTPException( status_code=400, - detail=f"{str(exc)}. Please verify your provider and model details by calling the /models endpoint or visiting https://tensorblock.co/api-docs/model-ids, and ensure you’re using a valid provider name, model name, and model ID." + detail=f"{str(exc)}. Please verify your provider and model details by calling the /models endpoint or visiting https://tensorblock.co/api-docs/model-ids, and ensure you're using a valid provider name, model name, and model ID." 
     )
 
 # Add exception handler for BaseInvalidProviderSetupException
diff --git a/app/services/providers/adapter_factory.py b/app/services/providers/adapter_factory.py
index 6a2f0d5..e98bb32 100644
--- a/app/services/providers/adapter_factory.py
+++ b/app/services/providers/adapter_factory.py
@@ -13,6 +13,7 @@
 from .perplexity_adapter import PerplexityAdapter
 from .tensorblock_adapter import TensorblockAdapter
 from .zhipu_adapter import ZhipuAdapter
+from .vertex_adapter import VertexAdapter
 
 
 class ProviderAdapterFactory:
@@ -185,6 +186,9 @@ class ProviderAdapterFactory:
         "bedrock": {
             "adapter": BedrockAdapter,
         },
+        "vertex": {
+            "adapter": VertexAdapter,
+        },
         "customized": {
             "adapter": OpenAIAdapter,
         },
diff --git a/app/services/providers/anthropic_adapter.py b/app/services/providers/anthropic_adapter.py
index 4533d31..3e41db1 100644
--- a/app/services/providers/anthropic_adapter.py
+++ b/app/services/providers/anthropic_adapter.py
@@ -3,7 +3,7 @@
 import uuid
 from collections.abc import AsyncGenerator
 from http import HTTPStatus
-from typing import Any
+from typing import Any, Callable
 
 import aiohttp
 
@@ -16,6 +16,8 @@
 
 ANTHROPIC_DEFAULT_MAX_TOKENS = 4096
 
+logger = get_logger(name="anthropic_adapter")
+
 
 class AnthropicAdapter(ProviderAdapter):
     """Adapter for Anthropic API"""
@@ -118,22 +120,10 @@ async def list_models(self, api_key: str) -> list[str]:
 
         self.cache_models(api_key, self._base_url, models)
 
         return models
-
-    async def process_completion(
-        self,
-        endpoint: str,
-        payload: dict[str, Any],
-        api_key: str,
-    ) -> Any:
-        """Process a completion request using Anthropic API"""
-        headers = {
-            "x-api-key": api_key,
-            "Content-Type": "application/json",
-            "anthropic-version": "2023-06-01",
-        }
-
-        # Convert OpenAI format to Anthropic format
-        streaming = payload.get("stream", False)
+
+    @staticmethod
+    def convert_openai_payload_to_anthropic(payload: dict[str, Any]) -> dict[str, Any]:
+        """Convert an OpenAI completion payload to Anthropic format"""
         anthropic_payload = {
             "model": payload["model"],
             "max_tokens": payload.get("max_completion_tokens", payload.get("max_tokens", ANTHROPIC_DEFAULT_MAX_TOKENS)),
@@ -150,7 +140,7 @@
         for msg in payload["messages"]:
             role = msg["role"]
             content = msg["content"]
-            content = self.convert_openai_content_to_anthropic(content)
+            content = AnthropicAdapter.convert_openai_content_to_anthropic(content)
 
             if role == "system":
                 # Anthropic requires a system message to be string
@@ -171,6 +161,25 @@
             # Handle regular completion (legacy format)
             anthropic_payload["prompt"] = f"Human: {payload['prompt']}\n\nAssistant: "
 
+        return anthropic_payload
+
+    async def process_completion(
+        self,
+        endpoint: str,
+        payload: dict[str, Any],
+        api_key: str,
+    ) -> Any:
+        """Process a completion request using Anthropic API"""
+        headers = {
+            "x-api-key": api_key,
+            "Content-Type": "application/json",
+            "anthropic-version": "2023-06-01",
+        }
+
+        streaming = payload.get("stream", False)
+        # Convert OpenAI format to Anthropic format
+        anthropic_payload = self.convert_openai_payload_to_anthropic(payload)
+
         # Choose the appropriate API endpoint - using ternary operator
         api_endpoint = "messages" if "messages" in anthropic_payload else "complete"
 
@@ -179,17 +188,18 @@
         # Handle streaming requests
         if streaming and "messages" in anthropic_payload:
             anthropic_payload["stream"] = True
-            return await self._stream_anthropic_response(
+            return await self.stream_anthropic_response(
                 url, headers, anthropic_payload, 
payload["model"] ) else: # For non-streaming, use the regular approach - return await self._process_regular_response( + return await self.process_regular_response( url, headers, anthropic_payload, payload["model"] ) - async def _stream_anthropic_response( - self, url, headers, anthropic_payload, model_name + @staticmethod + async def stream_anthropic_response( + url, headers, anthropic_payload, model_name, error_handler: Callable[[str, int], Any] | None = None ): """Handle streaming response from Anthropic API, including usage data.""" @@ -206,9 +216,9 @@ async def stream_response() -> AsyncGenerator[bytes, None]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}") + logger.error(f"Completion Streaming API error for anthropic: {error_text}") raise ProviderAPIException( - provider_name=self.provider_name, + provider_name="anthropic", error_code=response.status, error_message=error_text ) @@ -327,7 +337,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]: usage_info_complete = False except json.JSONDecodeError as e: - logger.warning(f"Stream API error for {self.provider_name}: Failed to parse JSON: {e}") + logger.warning(f"Stream API error for anthropic: Failed to parse JSON: {e}") continue except Exception as e: continue @@ -337,8 +347,9 @@ async def stream_response() -> AsyncGenerator[bytes, None]: return stream_response() - async def _process_regular_response( - self, url, headers, anthropic_payload, model_name + @staticmethod + async def process_regular_response( + url: str, headers: dict[str, str], anthropic_payload: dict[str, Any], model_name: str, error_handler: Callable[[str, int], Any] | None = None ): """Handle regular (non-streaming) response from Anthropic API""" # Single with statement for multiple contexts @@ -348,9 +359,9 @@ async def _process_regular_response( ): if response.status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Completion API error for {self.provider_name}: {error_text}") + logger.error(f"Completion API error for anthropic: {error_text}") raise ProviderAPIException( - provider_name=self.provider_name, + provider_name="anthropic", error_code=response.status, error_message=error_text ) diff --git a/app/services/providers/base.py b/app/services/providers/base.py index 05b6d4c..0ccb392 100644 --- a/app/services/providers/base.py +++ b/app/services/providers/base.py @@ -1,8 +1,16 @@ -import json import time from abc import ABC, abstractmethod from typing import Any, ClassVar +# Constants for API key masking +API_KEY_MASK_PREFIX_LENGTH = 2 +API_KEY_MASK_SUFFIX_LENGTH = 4 +# Minimum length to apply the full prefix + suffix mask (e.g., pr******fix) +# This means if length is > (PREFIX + SUFFIX), we can apply the full rule. 
+MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC = (
+    API_KEY_MASK_PREFIX_LENGTH + API_KEY_MASK_SUFFIX_LENGTH
+)
+
 
 class ProviderAdapter(ABC):
     """Base class for all provider adapters"""
@@ -66,6 +74,37 @@ def deserialize_api_key_config(serialized_api_key_config: str) -> tuple[str, dic
     def mask_config(config: dict[str, Any]) -> dict[str, Any]:
         """Mask the config for the given provider"""
         return config
+
+    @staticmethod
+    def mask_api_key(api_key: str) -> str | None:
+        """Mask the API key for the given provider"""
+        if not api_key:
+            return None
+
+        length = len(api_key)
+
+        if length == 0:
+            return ""
+
+        # If key is too short for any meaningful prefix/suffix masking
+        if length <= API_KEY_MASK_PREFIX_LENGTH:
+            return "*" * length
+
+        # If key is long enough for prefix, but not for prefix + suffix
+        # e.g., length is 3, 4, 5, 6. For these, show prefix and mask the rest.
+        if length <= MIN_KEY_LENGTH_FOR_FULL_MASK_LOGIC:
+            return api_key[:API_KEY_MASK_PREFIX_LENGTH] + "*" * (
+                length - API_KEY_MASK_PREFIX_LENGTH
+            )
+
+        # If key is long enough for the full prefix + ... + suffix mask
+        # number of asterisks = length - prefix_length - suffix_length
+        num_asterisks = length - API_KEY_MASK_PREFIX_LENGTH - API_KEY_MASK_SUFFIX_LENGTH
+        return (
+            api_key[:API_KEY_MASK_PREFIX_LENGTH]
+            + "*" * num_asterisks
+            + api_key[-API_KEY_MASK_SUFFIX_LENGTH:]
+        )
 
     def cache_models(
         self, api_key: str, base_url: str | None, models: list[str]
diff --git a/app/services/providers/vertex_adapter.py b/app/services/providers/vertex_adapter.py
new file mode 100644
index 0000000..2023c71
--- /dev/null
+++ b/app/services/providers/vertex_adapter.py
@@ -0,0 +1,216 @@
+import asyncio
+import json
+from collections.abc import AsyncGenerator
+from typing import Any
+import aiohttp
+from google.oauth2 import service_account
+from google.auth.transport.requests import Request
+from app.exceptions.exceptions import ProviderAuthenticationException, InvalidProviderConfigException, InvalidProviderAPIKeyException, ProviderAPIException
+
+from app.core.logger import get_logger
+
+from .base import ProviderAdapter
+from .anthropic_adapter import AnthropicAdapter
+
+logger = get_logger(name="vertex_adapter")
+
+
+class VertexAdapter(ProviderAdapter):
+    """Adapter for Vertex AI API"""
+
+    def __init__(self, provider_name: str, base_url: str | None = None, config: dict[str, str] | None = None):
+        self._provider_name = provider_name
+        self._base_url = base_url.rstrip("/") if base_url else None
+        self.config = config
+        self.parse_config(config)
+
+    @property
+    def provider_name(self) -> str:
+        return self._provider_name
+
+    @staticmethod
+    def validate_config(config: dict[str, str] | None):
+        """Validate the config for the given provider"""
+        try:
+            assert config is not None
+            assert config.get("publisher", "anthropic") is not None
+            assert config.get("location") is not None
+        except Exception as e:
+            raise InvalidProviderConfigException("Vertex", e)
+
+    def parse_config(self, config: dict[str, str] | None):
+        """Parse and validate the config for the given provider"""
+        self.validate_config(config)
+        self.publisher = config.get("publisher", "anthropic").lower()
+        self.location = config["location"].lower()
+
+    @staticmethod
+    def validate_api_key(api_key: str):
+        """Validate the API key for the given provider"""
+        try:
+            cred_json = json.loads(api_key)
+            assert cred_json["type"] == "service_account"
+            assert cred_json["project_id"] is not None
+            assert cred_json["private_key_id"] is not None
+            assert cred_json["private_key"] is not None
+            assert 
cred_json["client_email"] is not None + assert cred_json["client_id"] is not None + assert cred_json["auth_uri"] is not None + assert cred_json["token_uri"] is not None + assert cred_json["auth_provider_x509_cert_url"] is not None + assert cred_json["client_x509_cert_url"] is not None + assert cred_json["universe_domain"] is not None + + return cred_json + except Exception as e: + raise InvalidProviderAPIKeyException("Vertex", e) + + def parse_api_key(self, api_key: str): + """Validate the API key for the given provider""" + try: + cred_json = self.validate_api_key(api_key) + self.project_id = cred_json["project_id"] + self.cred_json = cred_json + except Exception as e: + raise ProviderAuthenticationException("Vertex", e) + + @staticmethod + def serialize_api_key_config(api_key: str, config: dict[str, Any] | None) -> str: + """Serialize the API key for the given provider""" + VertexAdapter.validate_api_key(api_key) + VertexAdapter.validate_config(config) + return json.dumps({ + "api_key": api_key, + "publisher": config.get("publisher", "anthropic"), + "location": config["location"], + }) + + @staticmethod + def deserialize_api_key_config(serialized_api_key_config: str) -> tuple[str, dict[str, Any] | None]: + """Deserialize the API key for the given provider""" + deserialized_api_key_config = json.loads(serialized_api_key_config) + return deserialized_api_key_config["api_key"], { + "publisher": deserialized_api_key_config["publisher"], + "location": deserialized_api_key_config["location"], + } + + @staticmethod + def mask_config(config: dict[str, Any] | None) -> dict[str, Any] | None: + """Mask the config for the given provider""" + VertexAdapter.validate_config(config) + return { + "publisher": config.get("publisher", "anthropic"), + "location": config["location"], + } + + @staticmethod + def mask_api_key(api_key: str) -> str: + """Mask the API key for the given provider""" + cred_json = VertexAdapter.validate_api_key(api_key) + return json.dumps({ + "type": cred_json["type"], + "project_id": ProviderAdapter.mask_api_key(cred_json["project_id"]), + "private_key_id": ProviderAdapter.mask_api_key(cred_json["private_key_id"]), + "private_key": ProviderAdapter.mask_api_key(cred_json["private_key"]), + "client_email": ProviderAdapter.mask_api_key(cred_json["client_email"]), + "client_id": ProviderAdapter.mask_api_key(cred_json["client_id"]), + "auth_uri": cred_json["auth_uri"], + "token_uri": cred_json["token_uri"], + "auth_provider_x509_cert_url": cred_json["auth_provider_x509_cert_url"], + "client_x509_cert_url": ProviderAdapter.mask_api_key(cred_json["client_x509_cert_url"]), + "universe_domain": cred_json["universe_domain"], + }) + + async def vertex_authentication(self, api_key: str) -> str: + # validate api key + self.parse_api_key(api_key) + + # load credentials within scope + try: + credentials = service_account.Credentials.from_service_account_info(self.cred_json, scopes=["https://www.googleapis.com/auth/cloud-platform"]) + + # refresh token - run in thread pool to avoid blocking + await asyncio.to_thread(credentials.refresh, Request()) + return credentials.token + except Exception as e: + logger.error(f"Error authenticating with Vertex API: {e}") + raise ProviderAuthenticationException("Vertex", e) + + async def list_models(self, api_key: str) -> list[str]: + """List all models (verbosely) supported by the provider""" + # Check cache first + cached_models = self.get_cached_models(api_key, self._base_url) + if cached_models is not None: + return cached_models + + token = await 
self.vertex_authentication(api_key)
+        headers = {
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+        }
+        url = f"{self._base_url}/v1beta1/publishers/{self.publisher}/models"
+        models = []
+        async with aiohttp.ClientSession() as session:
+            next_page_token = "###initial"
+            while next_page_token:
+                params = {}
+                if next_page_token and next_page_token != "###initial":
+                    params["pageToken"] = next_page_token
+                async with session.get(url, headers=headers, params=params) as response:
+                    results = await response.json()
+                    next_page_token = results.get("nextPageToken")
+                    for m in results["publisherModels"]:
+                        name = m["name"]
+                        version_id = m["versionId"]
+                        model_id = f"{name.split('/')[-1]}@{version_id}"
+                        models.append(model_id)
+
+        self.cache_models(api_key, self._base_url, models)
+        return models
+
+    async def process_completion(self, endpoint: str, payload: dict[str, Any], api_key: str) -> Any:
+        token = await self.vertex_authentication(api_key)
+        headers = {
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+        }
+
+        streaming = payload.get("stream", False)
+        model_name = payload["model"]
+        anthropic_payload = AnthropicAdapter.convert_openai_payload_to_anthropic(payload)
+
+        # vertex specific payload
+        anthropic_payload["anthropic_version"] = "vertex-2023-10-16"
+        del anthropic_payload["model"]
+
+        def error_handler(error_text: str, http_status: int):
+            try:
+                error_json = json.loads(error_text)
+                error_message = error_json.get("error", {}).get("message", "Unknown error")
+                error_code = error_json.get("error", {}).get("code", http_status)
+            except Exception:
+                # Fall back to the raw error text if the body is not JSON
+                raise ProviderAPIException("Vertex", http_status, error_text)
+            raise ProviderAPIException("Vertex", error_code, error_message)
+
+        if streaming:
+            # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamRawPredict
+            # Vertex doesn't do actual streaming; it just returns a stream of JSON objects
+            url = f"{self._base_url}/v1/projects/{self.project_id}/locations/{self.location}/publishers/{self.publisher}/models/{model_name}:streamRawPredict"
+            async def custom_stream_response(url, headers, anthropic_payload, model_name):
+                async def stream_response() -> AsyncGenerator[bytes, None]:
+                    resp = await AnthropicAdapter.process_regular_response(url, headers, anthropic_payload, model_name, error_handler)
+                    resp['object'] = 'chat.completion.chunk'
+                    for choice in resp['choices']:
+                        choice['delta'] = choice['message']
+                        del choice['message']
+                    yield f"data: {json.dumps(resp)}\n\n".encode()
+                    yield b"data: [DONE]\n\n"
+                return stream_response()
+            return await custom_stream_response(url, headers, anthropic_payload, model_name)
+        else:
+            url = f"{self._base_url}/v1/projects/{self.project_id}/locations/{self.location}/publishers/{self.publisher}/models/{model_name}:rawPredict"
+            return await AnthropicAdapter.process_regular_response(url, headers, anthropic_payload, model_name, error_handler)
+
+    async def process_embeddings(self, payload: dict[str, Any]) -> Any:
+        """Process an embeddings request using the Vertex API"""
+        raise NotImplementedError("Embedding for Vertex is not supported")
diff --git a/pyproject.toml b/pyproject.toml
index 5075668..7530a5a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -35,6 +35,8 @@ dependencies = [
     "redis>=4.6.0",  # sync & async clients used by shared cache
     "loguru>=0.7.0",
     "aiobotocore~=2.0",
+    "google-generativeai>=0.3.0",
+    "google-genai>=0.3.0",
 ]
 
 [project.optional-dependencies]
diff --git a/uv.lock b/uv.lock
index 
32eea87..cbcefc2 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,10 @@ version = 1 revision = 1 requires-python = ">=3.12" +resolution-markers = [ + "python_full_version >= '3.13'", + "python_full_version < '3.13'", +] [[package]] name = "aiobotocore" @@ -235,6 +239,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7e/83/a753562020b69fa90cebc39e8af2c753b24dcdc74bee8355ee3f6cefdf34/botocore-1.38.27-py3-none-any.whl", hash = "sha256:a785d5e9a5eda88ad6ab9ed8b87d1f2ac409d0226bba6ff801c55359e94d91a8", size = 13580545 }, ] +[[package]] +name = "cachetools" +version = "5.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080 }, +] + [[package]] name = "certifi" version = "2025.1.31" @@ -522,6 +535,8 @@ dependencies = [ { name = "cryptography" }, { name = "email-validator" }, { name = "fastapi" }, + { name = "google-genai" }, + { name = "google-generativeai" }, { name = "gunicorn" }, { name = "loguru" }, { name = "passlib" }, @@ -563,6 +578,8 @@ requires-dist = [ { name = "email-validator", specifier = ">=2.0.0" }, { name = "fastapi", specifier = ">=0.95.0" }, { name = "flake8", marker = "extra == 'dev'" }, + { name = "google-genai", specifier = ">=0.3.0" }, + { name = "google-generativeai", specifier = ">=0.3.0" }, { name = "gunicorn", specifier = ">=20.0.0" }, { name = "isort", marker = "extra == 'dev'" }, { name = "loguru", specifier = ">=0.7.0" }, @@ -625,6 +642,135 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/c8/a5be5b7550c10858fcf9b0ea054baccab474da77d37f1e828ce043a3a5d4/frozenlist-1.5.0-py3-none-any.whl", hash = "sha256:d994863bba198a4a518b467bb971c56e1db3f180a25c6cf7bb1949c267f748c3", size = 11901 }, ] +[[package]] +name = "google-ai-generativelanguage" +version = "0.6.15" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core", extra = ["grpc"] }, + { name = "google-auth" }, + { name = "proto-plus" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/d1/48fe5d7a43d278e9f6b5ada810b0a3530bbeac7ed7fcbcd366f932f05316/google_ai_generativelanguage-0.6.15.tar.gz", hash = "sha256:8f6d9dc4c12b065fe2d0289026171acea5183ebf2d0b11cefe12f3821e159ec3", size = 1375443 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/a3/67b8a6ff5001a1d8864922f2d6488dc2a14367ceb651bc3f09a947f2f306/google_ai_generativelanguage-0.6.15-py3-none-any.whl", hash = "sha256:5a03ef86377aa184ffef3662ca28f19eeee158733e45d7947982eb953c6ebb6c", size = 1327356 }, +] + +[[package]] +name = "google-api-core" +version = "2.25.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "googleapis-common-protos" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/21/e9d043e88222317afdbdb567165fdbc3b0aad90064c7e0c9eb0ad9955ad8/google_api_core-2.25.1.tar.gz", hash = "sha256:d2aaa0b13c78c61cb3f4282c464c046e45fbd75755683c9c525e6e8f7ed0a5e8", size = 165443 } +wheels = [ + 
{ url = "https://files.pythonhosted.org/packages/14/4b/ead00905132820b623732b175d66354e9d3e69fcf2a5dcdab780664e7896/google_api_core-2.25.1-py3-none-any.whl", hash = "sha256:8a2a56c1fef82987a524371f99f3bd0143702fecc670c72e600c1cda6bf8dbb7", size = 160807 }, +] + +[package.optional-dependencies] +grpc = [ + { name = "grpcio" }, + { name = "grpcio-status" }, +] + +[[package]] +name = "google-api-python-client" +version = "2.176.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "google-auth" }, + { name = "google-auth-httplib2" }, + { name = "httplib2" }, + { name = "uritemplate" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/38/daf70faf6d05556d382bac640bc6765f09fcfb9dfb51ac4a595d3453a2a9/google_api_python_client-2.176.0.tar.gz", hash = "sha256:2b451cdd7fd10faeb5dd20f7d992f185e1e8f4124c35f2cdcc77c843139a4cf1", size = 13154773 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/2c/758f415a19a12c3c6d06902794b0dd4c521d912a59b98ab752bba48812df/google_api_python_client-2.176.0-py3-none-any.whl", hash = "sha256:e22239797f1d085341e12cd924591fc65c56d08e0af02549d7606092e6296510", size = 13678445 }, +] + +[[package]] +name = "google-auth" +version = "2.40.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/9b/e92ef23b84fa10a64ce4831390b7a4c2e53c0132568d99d4ae61d04c8855/google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77", size = 281029 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/63/b19553b658a1692443c62bd07e5868adaa0ad746a0751ba62c59568cd45b/google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca", size = 216137 }, +] + +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "httplib2" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/56/be/217a598a818567b28e859ff087f347475c807a5649296fb5a817c58dacef/google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05", size = 10842 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/8a/fe34d2f3f9470a27b01c9e76226965863f153d5fbe276f83608562e49c04/google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d", size = 9253 }, +] + +[[package]] +name = "google-genai" +version = "1.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "google-auth" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8d/cf/37ac8cd4752e28e547b8a52765fe48a2ada2d0d286ea03f46e4d8c69ff4f/google_genai-1.24.0.tar.gz", hash = "sha256:bc896e30ad26d05a2af3d17c2ba10ea214a94f1c0cdb93d5c004dc038774e75a", size = 226740 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/30/28/a35f64fc02e599808101617a21d447d241dadeba2aac1f4dc2d1179b8218/google_genai-1.24.0-py3-none-any.whl", hash = "sha256:98be8c51632576289ecc33cd84bcdaf4356ef0bef04ac7578660c49175af22b9", size = 226065 }, +] + +[[package]] +name = 
"google-generativeai" +version = "0.8.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-ai-generativelanguage" }, + { name = "google-api-core" }, + { name = "google-api-python-client" }, + { name = "google-auth" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/6e/40/c42ff9ded9f09ec9392879a8e6538a00b2dc185e834a3392917626255419/google_generativeai-0.8.5-py3-none-any.whl", hash = "sha256:22b420817fb263f8ed520b33285f45976d5b21e904da32b80d4fd20c055123a2", size = 155427 }, +] + +[[package]] +name = "googleapis-common-protos" +version = "1.70.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/24/33db22342cf4a2ea27c9955e6713140fedd51e8b141b5ce5260897020f1a/googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257", size = 145903 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530 }, +] + [[package]] name = "greenlet" version = "3.1.1" @@ -658,6 +804,48 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ac/38/08cc303ddddc4b3d7c628c3039a61a3aae36c241ed01393d00c2fd663473/greenlet-3.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:411f015496fec93c1c8cd4e5238da364e1da7a124bcb293f085bf2860c32c6f6", size = 1142112 }, ] +[[package]] +name = "grpcio" +version = "1.73.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/79/e8/b43b851537da2e2f03fa8be1aef207e5cbfb1a2e014fbb6b40d24c177cd3/grpcio-1.73.1.tar.gz", hash = "sha256:7fce2cd1c0c1116cf3850564ebfc3264fba75d3c74a7414373f1238ea365ef87", size = 12730355 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/41/456caf570c55d5ac26f4c1f2db1f2ac1467d5bf3bcd660cba3e0a25b195f/grpcio-1.73.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:921b25618b084e75d424a9f8e6403bfeb7abef074bb6c3174701e0f2542debcf", size = 5334621 }, + { url = "https://files.pythonhosted.org/packages/2a/c2/9a15e179e49f235bb5e63b01590658c03747a43c9775e20c4e13ca04f4c4/grpcio-1.73.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:277b426a0ed341e8447fbf6c1d6b68c952adddf585ea4685aa563de0f03df887", size = 10601131 }, + { url = "https://files.pythonhosted.org/packages/0c/1d/1d39e90ef6348a0964caa7c5c4d05f3bae2c51ab429eb7d2e21198ac9b6d/grpcio-1.73.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:96c112333309493c10e118d92f04594f9055774757f5d101b39f8150f8c25582", size = 5759268 }, + { url = "https://files.pythonhosted.org/packages/8a/2b/2dfe9ae43de75616177bc576df4c36d6401e0959833b2e5b2d58d50c1f6b/grpcio-1.73.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f48e862aed925ae987eb7084409a80985de75243389dc9d9c271dd711e589918", size = 6409791 }, + { url = "https://files.pythonhosted.org/packages/6e/66/e8fe779b23b5a26d1b6949e5c70bc0a5fd08f61a6ec5ac7760d589229511/grpcio-1.73.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83a6c2cce218e28f5040429835fa34a29319071079e3169f9543c3fbeff166d2", size = 6003728 }, + { url = 
"https://files.pythonhosted.org/packages/a9/39/57a18fcef567784108c4fc3f5441cb9938ae5a51378505aafe81e8e15ecc/grpcio-1.73.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:65b0458a10b100d815a8426b1442bd17001fdb77ea13665b2f7dc9e8587fdc6b", size = 6103364 }, + { url = "https://files.pythonhosted.org/packages/c5/46/28919d2aa038712fc399d02fa83e998abd8c1f46c2680c5689deca06d1b2/grpcio-1.73.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:0a9f3ea8dce9eae9d7cb36827200133a72b37a63896e0e61a9d5ec7d61a59ab1", size = 6749194 }, + { url = "https://files.pythonhosted.org/packages/3d/56/3898526f1fad588c5d19a29ea0a3a4996fb4fa7d7c02dc1be0c9fd188b62/grpcio-1.73.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:de18769aea47f18e782bf6819a37c1c528914bfd5683b8782b9da356506190c8", size = 6283902 }, + { url = "https://files.pythonhosted.org/packages/dc/64/18b77b89c5870d8ea91818feb0c3ffb5b31b48d1b0ee3e0f0d539730fea3/grpcio-1.73.1-cp312-cp312-win32.whl", hash = "sha256:24e06a5319e33041e322d32c62b1e728f18ab8c9dbc91729a3d9f9e3ed336642", size = 3668687 }, + { url = "https://files.pythonhosted.org/packages/3c/52/302448ca6e52f2a77166b2e2ed75f5d08feca4f2145faf75cb768cccb25b/grpcio-1.73.1-cp312-cp312-win_amd64.whl", hash = "sha256:303c8135d8ab176f8038c14cc10d698ae1db9c480f2b2823f7a987aa2a4c5646", size = 4334887 }, + { url = "https://files.pythonhosted.org/packages/37/bf/4ca20d1acbefabcaba633ab17f4244cbbe8eca877df01517207bd6655914/grpcio-1.73.1-cp313-cp313-linux_armv7l.whl", hash = "sha256:b310824ab5092cf74750ebd8a8a8981c1810cb2b363210e70d06ef37ad80d4f9", size = 5335615 }, + { url = "https://files.pythonhosted.org/packages/75/ed/45c345f284abec5d4f6d77cbca9c52c39b554397eb7de7d2fcf440bcd049/grpcio-1.73.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:8f5a6df3fba31a3485096ac85b2e34b9666ffb0590df0cd044f58694e6a1f6b5", size = 10595497 }, + { url = "https://files.pythonhosted.org/packages/a4/75/bff2c2728018f546d812b755455014bc718f8cdcbf5c84f1f6e5494443a8/grpcio-1.73.1-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:052e28fe9c41357da42250a91926a3e2f74c046575c070b69659467ca5aa976b", size = 5765321 }, + { url = "https://files.pythonhosted.org/packages/70/3b/14e43158d3b81a38251b1d231dfb45a9b492d872102a919fbf7ba4ac20cd/grpcio-1.73.1-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c0bf15f629b1497436596b1cbddddfa3234273490229ca29561209778ebe182", size = 6415436 }, + { url = "https://files.pythonhosted.org/packages/e5/3f/81d9650ca40b54338336fd360f36773be8cb6c07c036e751d8996eb96598/grpcio-1.73.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ab860d5bfa788c5a021fba264802e2593688cd965d1374d31d2b1a34cacd854", size = 6007012 }, + { url = "https://files.pythonhosted.org/packages/55/f4/59edf5af68d684d0f4f7ad9462a418ac517201c238551529098c9aa28cb0/grpcio-1.73.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:ad1d958c31cc91ab050bd8a91355480b8e0683e21176522bacea225ce51163f2", size = 6105209 }, + { url = "https://files.pythonhosted.org/packages/e4/a8/700d034d5d0786a5ba14bfa9ce974ed4c976936c2748c2bd87aa50f69b36/grpcio-1.73.1-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:f43ffb3bd415c57224c7427bfb9e6c46a0b6e998754bfa0d00f408e1873dcbb5", size = 6753655 }, + { url = "https://files.pythonhosted.org/packages/1f/29/efbd4ac837c23bc48e34bbaf32bd429f0dc9ad7f80721cdb4622144c118c/grpcio-1.73.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:686231cdd03a8a8055f798b2b54b19428cdf18fa1549bee92249b43607c42668", size = 6287288 }, + { url = 
"https://files.pythonhosted.org/packages/d8/61/c6045d2ce16624bbe18b5d169c1a5ce4d6c3a47bc9d0e5c4fa6a50ed1239/grpcio-1.73.1-cp313-cp313-win32.whl", hash = "sha256:89018866a096e2ce21e05eabed1567479713ebe57b1db7cbb0f1e3b896793ba4", size = 3668151 }, + { url = "https://files.pythonhosted.org/packages/c2/d7/77ac689216daee10de318db5aa1b88d159432dc76a130948a56b3aa671a2/grpcio-1.73.1-cp313-cp313-win_amd64.whl", hash = "sha256:4a68f8c9966b94dff693670a5cf2b54888a48a5011c5d9ce2295a1a1465ee84f", size = 4335747 }, +] + +[[package]] +name = "grpcio-status" +version = "1.71.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "googleapis-common-protos" }, + { name = "grpcio" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/d1/b6e9877fedae3add1afdeae1f89d1927d296da9cf977eca0eb08fb8a460e/grpcio_status-1.71.2.tar.gz", hash = "sha256:c7a97e176df71cdc2c179cd1847d7fc86cca5832ad12e9798d7fed6b7a1aab50", size = 13677 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/67/58/317b0134129b556a93a3b0afe00ee675b5657f0155509e22fcb853bafe2d/grpcio_status-1.71.2-py3-none-any.whl", hash = "sha256:803c98cb6a8b7dc6dbb785b1111aed739f241ab5e9da0bba96888aa74704cfd3", size = 14424 }, +] + [[package]] name = "gunicorn" version = "23.0.0" @@ -692,6 +880,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/8d/f052b1e336bb2c1fc7ed1aaed898aa570c0b61a09707b108979d9fc6e308/httpcore-1.0.8-py3-none-any.whl", hash = "sha256:5254cf149bcb5f75e9d1b2b9f729ea4a4b883d1ad7379fc632b727cec23674be", size = 78732 }, ] +[[package]] +name = "httplib2" +version = "0.22.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyparsing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/ad/2371116b22d616c194aa25ec410c9c6c37f23599dcd590502b74db197584/httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81", size = 351116 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/6c/d2fbdaaa5959339d53ba38e94c123e4e84b8fbc4b84beb0e70d7c1608486/httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc", size = 96854 }, +] + [[package]] name = "httpx" version = "0.28.1" @@ -1040,6 +1240,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/d3/c3cb8f1d6ae3b37f83e1de806713a9b3642c5895f0215a62e1a4bd6e5e34/propcache-0.3.1-py3-none-any.whl", hash = "sha256:9a8ecf38de50a7f518c21568c80f985e776397b902f1ce0b01f799aba1608b40", size = 12376 }, ] +[[package]] +name = "proto-plus" +version = "1.26.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163 }, +] + +[[package]] +name = "protobuf" +version = "5.29.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/29/d09e70352e4e88c9c7a198d5645d7277811448d76c23b00345670f7c8a38/protobuf-5.29.5.tar.gz", hash = 
"sha256:bc1463bafd4b0929216c35f437a8e28731a2b7fe3d98bb77a600efced5a15c84", size = 425226 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/11/6e40e9fc5bba02988a214c07cf324595789ca7820160bfd1f8be96e48539/protobuf-5.29.5-cp310-abi3-win32.whl", hash = "sha256:3f1c6468a2cfd102ff4703976138844f78ebd1fb45f49011afc5139e9e283079", size = 422963 }, + { url = "https://files.pythonhosted.org/packages/81/7f/73cefb093e1a2a7c3ffd839e6f9fcafb7a427d300c7f8aef9c64405d8ac6/protobuf-5.29.5-cp310-abi3-win_amd64.whl", hash = "sha256:3f76e3a3675b4a4d867b52e4a5f5b78a2ef9565549d4037e06cf7b0942b1d3fc", size = 434818 }, + { url = "https://files.pythonhosted.org/packages/dd/73/10e1661c21f139f2c6ad9b23040ff36fee624310dc28fba20d33fdae124c/protobuf-5.29.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e38c5add5a311f2a6eb0340716ef9b039c1dfa428b28f25a7838ac329204a671", size = 418091 }, + { url = "https://files.pythonhosted.org/packages/6c/04/98f6f8cf5b07ab1294c13f34b4e69b3722bb609c5b701d6c169828f9f8aa/protobuf-5.29.5-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:fa18533a299d7ab6c55a238bf8629311439995f2e7eca5caaff08663606e9015", size = 319824 }, + { url = "https://files.pythonhosted.org/packages/85/e4/07c80521879c2d15f321465ac24c70efe2381378c00bf5e56a0f4fbac8cd/protobuf-5.29.5-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:63848923da3325e1bf7e9003d680ce6e14b07e55d0473253a690c3a8b8fd6e61", size = 319942 }, + { url = "https://files.pythonhosted.org/packages/7e/cc/7e77861000a0691aeea8f4566e5d3aa716f2b1dece4a24439437e41d3d25/protobuf-5.29.5-py3-none-any.whl", hash = "sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5", size = 172823 }, +] + [[package]] name = "psycopg2-binary" version = "2.9.10" @@ -1080,6 +1306,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/1e/a94a8d635fa3ce4cfc7f506003548d0a2447ae76fd5ca53932970fe3053f/pyasn1-0.4.8-py2.py3-none-any.whl", hash = "sha256:39c7e2ec30515947ff4e87fb6f456dfc6e84857d34be479c9d4a4ba4bf46aa5d", size = 77145 }, ] +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/67/6afbf0d507f73c32d21084a79946bfcfca5fbc62a72057e9c23797a737c9/pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c", size = 310028 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/89/bc88a6711935ba795a679ea6ebee07e128050d6382eaa35a0a47c8032bdc/pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd", size = 181537 }, +] + [[package]] name = "pycodestyle" version = "2.14.0" @@ -1164,6 +1402,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/2f/81d580a0fb83baeb066698975cb14a618bdbed7720678566f1b046a95fe8/pyflakes-3.4.0-py2.py3-none-any.whl", hash = "sha256:f742a7dbd0d9cb9ea41e9a24a918996e8170c799fa528688d40dd582c8265f4f", size = 63551 }, ] +[[package]] +name = "pyparsing" +version = "3.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/22/f1129e69d94ffff626bdb5c835506b3a5b4f3d070f17ea295e12c2c6f60f/pyparsing-3.2.3.tar.gz", hash = "sha256:b9c13f1ab8b3b542f72e28f634bad4de758ab3ce4546e4301970ad6fa77c38be", size = 1088608 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/05/e7/df2285f3d08fee213f2d041540fa4fc9ca6c2d44cf36d3a035bf2a8d2bcc/pyparsing-3.2.3-py3-none-any.whl", hash = "sha256:a749938e02d6fd0b59b356ca504a24982314bb090c383e3cf201c95ef7e2bfcf", size = 111120 }, +] + [[package]] name = "pytest" version = "8.3.5" @@ -1429,6 +1676,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/f3/5633e45bc01825c4464b6b1e98e05052e532139e827c4ea8c54f5eafb022/svix-1.67.0-py3-none-any.whl", hash = "sha256:4f195bea0ac7c33c54f29bb486e3814e9c50123be303bfba5064d1e607274668", size = 95009 }, ] +[[package]] +name = "tenacity" +version = "8.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/4d/6a19536c50b849338fcbe9290d562b52cbdcf30d8963d3588a68a4107df1/tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78", size = 47309 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/3f/8ba87d9e287b9d385a02a7114ddcef61b26f86411e121c9003eb509a1773/tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687", size = 28165 }, +] + +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540 }, +] + [[package]] name = "types-deprecated" version = "1.2.15.20250304" @@ -1468,6 +1736,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/08/aa4fdfb71f7de5176385bd9e90852eaf6b5d622735020ad600f2bab54385/typing_inspection-0.4.0-py3-none-any.whl", hash = "sha256:50e72559fcd2a6367a19f7a7e610e6afcb9fac940c650290eed893d61386832f", size = 14125 }, ] +[[package]] +name = "uritemplate" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/60/f174043244c5306c9988380d2cb10009f91563fc4b31293d27e17201af56/uritemplate-4.2.0.tar.gz", hash = "sha256:480c2ed180878955863323eea31b0ede668795de182617fef9c6ca09e6ec9d0e", size = 33267 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/99/3ae339466c9183ea5b8ae87b34c0b897eda475d2aec2307cae60e5cd4f29/uritemplate-4.2.0-py3-none-any.whl", hash = "sha256:962201ba1c4edcab02e60f9a0d3821e82dfc5d2d6662a21abd533879bdb8a686", size = 11488 }, +] + [[package]] name = "urllib3" version = "2.3.0" @@ -1504,6 +1781,37 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/eb/c6db6e3001d58c6a9e67c74bb7b4206767caa3ccc28c6b9eaf4c23fb4e34/virtualenv-20.29.3-py3-none-any.whl", hash = "sha256:3e3d00f5807e83b234dfb6122bf37cfadf4be216c53a49ac059d02414f819170", size = 4301458 }, ] +[[package]] +name = "websockets" +version = "15.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/21/e6/26d09fab466b7ca9c7737474c52be4f76a40301b08362eb2dbc19dcc16c1/websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee", size = 177016 } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/51/6b/4545a0d843594f5d0771e86463606a3988b5a09ca5123136f8a76580dd63/websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3", size = 175437 }, + { url = "https://files.pythonhosted.org/packages/f4/71/809a0f5f6a06522af902e0f2ea2757f71ead94610010cf570ab5c98e99ed/websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665", size = 173096 }, + { url = "https://files.pythonhosted.org/packages/3d/69/1a681dd6f02180916f116894181eab8b2e25b31e484c5d0eae637ec01f7c/websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2", size = 173332 }, + { url = "https://files.pythonhosted.org/packages/a6/02/0073b3952f5bce97eafbb35757f8d0d54812b6174ed8dd952aa08429bcc3/websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215", size = 183152 }, + { url = "https://files.pythonhosted.org/packages/74/45/c205c8480eafd114b428284840da0b1be9ffd0e4f87338dc95dc6ff961a1/websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5", size = 182096 }, + { url = "https://files.pythonhosted.org/packages/14/8f/aa61f528fba38578ec553c145857a181384c72b98156f858ca5c8e82d9d3/websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65", size = 182523 }, + { url = "https://files.pythonhosted.org/packages/ec/6d/0267396610add5bc0d0d3e77f546d4cd287200804fe02323797de77dbce9/websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe", size = 182790 }, + { url = "https://files.pythonhosted.org/packages/02/05/c68c5adbf679cf610ae2f74a9b871ae84564462955d991178f95a1ddb7dd/websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4", size = 182165 }, + { url = "https://files.pythonhosted.org/packages/29/93/bb672df7b2f5faac89761cb5fa34f5cec45a4026c383a4b5761c6cea5c16/websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597", size = 182160 }, + { url = "https://files.pythonhosted.org/packages/ff/83/de1f7709376dc3ca9b7eeb4b9a07b4526b14876b6d372a4dc62312bebee0/websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9", size = 176395 }, + { url = "https://files.pythonhosted.org/packages/7d/71/abf2ebc3bbfa40f391ce1428c7168fb20582d0ff57019b69ea20fa698043/websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7", size = 176841 }, + { url = "https://files.pythonhosted.org/packages/cb/9f/51f0cf64471a9d2b4d0fc6c534f323b664e7095640c34562f5182e5a7195/websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931", size = 175440 }, + { url = 
"https://files.pythonhosted.org/packages/8a/05/aa116ec9943c718905997412c5989f7ed671bc0188ee2ba89520e8765d7b/websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675", size = 173098 }, + { url = "https://files.pythonhosted.org/packages/ff/0b/33cef55ff24f2d92924923c99926dcce78e7bd922d649467f0eda8368923/websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151", size = 173329 }, + { url = "https://files.pythonhosted.org/packages/31/1d/063b25dcc01faa8fada1469bdf769de3768b7044eac9d41f734fd7b6ad6d/websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22", size = 183111 }, + { url = "https://files.pythonhosted.org/packages/93/53/9a87ee494a51bf63e4ec9241c1ccc4f7c2f45fff85d5bde2ff74fcb68b9e/websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f", size = 182054 }, + { url = "https://files.pythonhosted.org/packages/ff/b2/83a6ddf56cdcbad4e3d841fcc55d6ba7d19aeb89c50f24dd7e859ec0805f/websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8", size = 182496 }, + { url = "https://files.pythonhosted.org/packages/98/41/e7038944ed0abf34c45aa4635ba28136f06052e08fc2168520bb8b25149f/websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375", size = 182829 }, + { url = "https://files.pythonhosted.org/packages/e0/17/de15b6158680c7623c6ef0db361da965ab25d813ae54fcfeae2e5b9ef910/websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d", size = 182217 }, + { url = "https://files.pythonhosted.org/packages/33/2b/1f168cb6041853eef0362fb9554c3824367c5560cbdaad89ac40f8c2edfc/websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4", size = 182195 }, + { url = "https://files.pythonhosted.org/packages/86/eb/20b6cdf273913d0ad05a6a14aed4b9a85591c18a987a3d47f20fa13dcc47/websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa", size = 176393 }, + { url = "https://files.pythonhosted.org/packages/1b/6c/c65773d6cab416a64d191d6ee8a8b1c68a09970ea6909d16965d26bfed1e/websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561", size = 176837 }, + { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743 }, +] + [[package]] name = "win32-setctime" version = "1.2.0" From 0714c3150c39260e414bcfba2169d1cf29e1ff17 Mon Sep 17 00:00:00 2001 From: Dokujaa Date: Thu, 17 Jul 2025 22:37:34 -0400 Subject: [PATCH 09/10] feat: implement OAuth2 token caching for Vertex AI authentication - Add OAuth2 token caching functions to async_cache.py and cache.py - Create async_oauth_token_cache with 55-minute TTL (5-min safety buffer) - Update vertex_authentication method to check 
cache first before token refresh - Add OAuth2 token cache statistics tracking - Cache key format: token:{api_key} - Performance optimization: reduce unnecessary token refresh calls --- app/core/async_cache.py | 41 ++++++++++++++++++++++-- app/core/cache.py | 41 ++++++++++++++++++++++-- app/services/providers/vertex_adapter.py | 28 ++++++++++++++++ 3 files changed, 106 insertions(+), 4 deletions(-) diff --git a/app/core/async_cache.py b/app/core/async_cache.py index c3ba536..b92781b 100644 --- a/app/core/async_cache.py +++ b/app/core/async_cache.py @@ -156,6 +156,8 @@ async def wrapper(*args, **kwargs): async_provider_service_cache: "AsyncCache" = _AsyncBackend( ttl_seconds=3600 ) # 1-hour TTL +# OAuth2 token caching (55-min TTL with 5-min safety buffer before 1-hour token expiry) +async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=3300) # 55-min TTL # User-specific functions @@ -336,6 +338,33 @@ async def invalidate_provider_service_cache_async(user_id: int) -> None: ) +# OAuth2 token caching functions +async def get_cached_oauth_token_async(api_key: str) -> dict[str, Any] | None: + """Get a cached OAuth2 token by API key asynchronously""" + if not api_key: + return None + cached_data = await async_oauth_token_cache.get(f"token:{api_key}") + if cached_data: + return cached_data + return None + + +async def cache_oauth_token_async(api_key: str, token_data: dict[str, Any]) -> None: + """Cache an OAuth2 token by API key asynchronously""" + if not api_key or not token_data: + return + await async_oauth_token_cache.set(f"token:{api_key}", token_data) + + +async def invalidate_oauth_token_cache_async(api_key: str) -> None: + """Invalidate OAuth2 token cache for a specific API key asynchronously""" + if not api_key: + return + await async_oauth_token_cache.delete(f"token:{api_key}") + if DEBUG_CACHE: + logger.debug(f"Cache: Invalidated OAuth2 token cache for key: {api_key[:8]}...") + + async def invalidate_provider_models_cache_async(provider_name: str) -> None: """Invalidate model cache for a specific provider asynchronously""" if not provider_name: @@ -387,6 +416,7 @@ async def invalidate_all_caches_async() -> None: """Invalidate all caches in the system asynchronously""" await async_user_cache.clear() await async_provider_service_cache.clear() + await async_oauth_token_cache.clear() if DEBUG_CACHE: logger.debug("Cache: Invalidated all caches") @@ -429,6 +459,7 @@ async def get_cache_stats_async() -> dict[str, dict[str, Any]]: return { "user_cache": await async_user_cache.stats(), "provider_service_cache": await async_provider_service_cache.stats(), + "oauth_token_cache": await async_oauth_token_cache.stats(), } @@ -437,9 +468,15 @@ async def monitor_cache_performance_async() -> dict[str, Any]: stats = await get_cache_stats_async() # Calculate overall hit rates - total_hits = stats["user_cache"]["hits"] + stats["provider_service_cache"]["hits"] + total_hits = ( + stats["user_cache"]["hits"] + + stats["provider_service_cache"]["hits"] + + stats["oauth_token_cache"]["hits"] + ) total_requests = ( - stats["user_cache"]["total"] + stats["provider_service_cache"]["total"] + stats["user_cache"]["total"] + + stats["provider_service_cache"]["total"] + + stats["oauth_token_cache"]["total"] ) overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0 diff --git a/app/core/cache.py b/app/core/cache.py index 7fe9fbe..14be802 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -119,6 +119,8 @@ def stats(self) -> dict[str, Any]: # Expose the global cache instances 
user_cache: "Cache" = _CacheBackend(ttl_seconds=300) # 5-minute TTL for users provider_service_cache: "Cache" = _CacheBackend(ttl_seconds=3600) # 1-hour TTL +# OAuth2 token caching (55-min TTL with 5-min safety buffer before 1-hour token expiry) +oauth_token_cache: "Cache" = _CacheBackend(ttl_seconds=3300) # 55-min TTL def cached(cache_instance: Cache, key_func: Callable[[Any], str] = None): @@ -238,6 +240,33 @@ def invalidate_provider_service_cache(user_id: int) -> None: ) +# OAuth2 token caching functions +def get_cached_oauth_token(api_key: str) -> dict[str, Any] | None: + """Get a cached OAuth2 token by API key""" + if not api_key: + return None + cached_data = oauth_token_cache.get(f"token:{api_key}") + if cached_data: + return cached_data + return None + + +def cache_oauth_token(api_key: str, token_data: dict[str, Any]) -> None: + """Cache an OAuth2 token by API key""" + if not api_key or not token_data: + return + oauth_token_cache.set(f"token:{api_key}", token_data) + + +def invalidate_oauth_token_cache(api_key: str) -> None: + """Invalidate OAuth2 token cache for a specific API key""" + if not api_key: + return + oauth_token_cache.delete(f"token:{api_key}") + if DEBUG_CACHE: + logger.debug(f"Cache: Invalidated OAuth2 token cache for key: {api_key[:8]}...") + + def invalidate_user_cache_by_id(user_id: int) -> None: """Invalidate all cache entries for a specific user ID""" if not user_id: @@ -325,6 +354,7 @@ def invalidate_all_caches() -> None: """Invalidate all caches in the system""" user_cache.clear() provider_service_cache.clear() + oauth_token_cache.clear() if DEBUG_CACHE: logger.debug("Cache: Invalidated all caches") @@ -362,6 +392,7 @@ def get_cache_stats() -> dict[str, dict[str, Any]]: return { "user_cache": user_cache.stats(), "provider_service_cache": provider_service_cache.stats(), + "oauth_token_cache": oauth_token_cache.stats(), } @@ -370,9 +401,15 @@ def monitor_cache_performance() -> dict[str, Any]: stats = get_cache_stats() # Calculate overall hit rates - total_hits = stats["user_cache"]["hits"] + stats["provider_service_cache"]["hits"] + total_hits = ( + stats["user_cache"]["hits"] + + stats["provider_service_cache"]["hits"] + + stats["oauth_token_cache"]["hits"] + ) total_requests = ( - stats["user_cache"]["total"] + stats["provider_service_cache"]["total"] + stats["user_cache"]["total"] + + stats["provider_service_cache"]["total"] + + stats["oauth_token_cache"]["total"] ) overall_hit_rate = total_hits / total_requests if total_requests > 0 else 0.0 diff --git a/app/services/providers/vertex_adapter.py b/app/services/providers/vertex_adapter.py index 2023c71..b59ef53 100644 --- a/app/services/providers/vertex_adapter.py +++ b/app/services/providers/vertex_adapter.py @@ -1,12 +1,14 @@ import asyncio import json from collections.abc import AsyncGenerator +from datetime import datetime, timezone from typing import Any import aiohttp from google.oauth2 import service_account from google.auth.transport.requests import Request from app.exceptions.exceptions import ProviderAuthenticationException, InvalidProviderConfigException, InvalidProviderAPIKeyException, ProviderAPIException +from app.core.async_cache import get_cached_oauth_token_async, cache_oauth_token_async, invalidate_oauth_token_cache_async from app.core.logger import get_logger from .base import ProviderAdapter @@ -125,12 +127,38 @@ async def vertex_authentication(self, api_key: str) -> str: # validate api key self.parse_api_key(api_key) + # check cache first for existing valid token + cached_token = await 
get_cached_oauth_token_async(api_key) + if cached_token: + token_str = cached_token.get("token") + expiry_str = cached_token.get("expiry") + if token_str and expiry_str: + try: + expiry = datetime.fromisoformat(expiry_str) + # Make expiry timezone-aware if it's naive (Google credentials are UTC) + if expiry.tzinfo is None: + expiry = expiry.replace(tzinfo=timezone.utc) + if expiry > datetime.now(timezone.utc): + return token_str + except (ValueError, TypeError): + # Invalid cached token, clear it and continue to refresh + await invalidate_oauth_token_cache_async(api_key) + # load credentials within scope try: credentials = service_account.Credentials.from_service_account_info(self.cred_json, scopes=["https://www.googleapis.com/auth/cloud-platform"]) # refresh token - run in thread pool to avoid blocking await asyncio.to_thread(credentials.refresh, Request()) + + # cache the token with expiry information + if credentials.token and credentials.expiry: + token_data = { + "token": credentials.token, + "expiry": credentials.expiry.isoformat() + } + await cache_oauth_token_async(api_key, token_data) + return credentials.token except Exception as e: logger.error(f"Error authenticating with Vertex API: {e}") From 2c47e8e85f74638048f8e64fcee89f3c5232dd95 Mon Sep 17 00:00:00 2001 From: Dokujaa Date: Fri, 18 Jul 2025 15:45:24 -0400 Subject: [PATCH 10/10] feat: implement hash-based OAuth2 token caching with smart cleanup * Replace fixed TTL with token's native expires_at timestamp validation * Use SHA-256 hashed cache keys for long JSON service account credentials * Add opportunistic cleanup to prevent expired token memory leaks * Standardize token structure with access_token, expires_at, token_type fields * Simplify Vertex AI authentication using service account re-authentication * Move imports to top-level for better code quality Addresses PR feedback on cache key security, TTL complexity, and memory management. 
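
For reviewers, a minimal self-contained sketch of the scheme this patch
adopts (hashed cache keys plus native-expiry validation). The dict backend
and function names below are illustrative stand-ins, not the actual Cache
classes in app/core:

    import hashlib
    import time
    from typing import Any

    _store: dict[str, dict[str, Any]] = {}  # stand-in for the cache backend

    def _cache_key(api_key: str) -> str:
        # SHA-256 keeps long service-account JSON out of the key space
        return f"token:{hashlib.sha256(api_key.encode()).hexdigest()}"

    def put_token(api_key: str, token_data: dict[str, Any]) -> None:
        # tokens without their own expiry are never cached
        if "expires_at" in token_data:
            _store[_cache_key(api_key)] = token_data

    def get_token(api_key: str) -> dict[str, Any] | None:
        entry = _store.get(_cache_key(api_key))
        # validate against the token's own expiry instead of a cache TTL
        if entry is not None and entry["expires_at"] <= time.time():
            del _store[_cache_key(api_key)]
            return None
        return entry
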
--- app/core/async_cache.py | 84 +++++++++++++++++++++--- app/services/providers/vertex_adapter.py | 25 +++---- 2 files changed, 85 insertions(+), 24 deletions(-) diff --git a/app/core/async_cache.py b/app/core/async_cache.py index b92781b..3775e83 100644 --- a/app/core/async_cache.py +++ b/app/core/async_cache.py @@ -5,6 +5,7 @@ import asyncio import functools +import hashlib import os import time from collections.abc import Callable @@ -156,8 +157,8 @@ async def wrapper(*args, **kwargs): async_provider_service_cache: "AsyncCache" = _AsyncBackend( ttl_seconds=3600 ) # 1-hour TTL -# OAuth2 token caching (55-min TTL with 5-min safety buffer before 1-hour token expiry) -async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=3300) # 55-min TTL +# OAuth2 token caching (no TTL - uses token's own expiration with smart cleanup) +async_oauth_token_cache: "AsyncCache" = _AsyncBackend(ttl_seconds=None) # User-specific functions @@ -343,26 +344,91 @@ async def get_cached_oauth_token_async(api_key: str) -> dict[str, Any] | None: """Get a cached OAuth2 token by API key asynchronously""" if not api_key: return None - cached_data = await async_oauth_token_cache.get(f"token:{api_key}") - if cached_data: - return cached_data - return None + + cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}" + cached_data = await async_oauth_token_cache.get(cache_key) + if not cached_data: + return None + + expires_at = cached_data.get("expires_at") + if not expires_at: + await async_oauth_token_cache.delete(cache_key) + return None + + current_time = time.time() + if expires_at <= current_time: + await async_oauth_token_cache.delete(cache_key) + await _opportunistic_cleanup(current_time, max_items=2) + return None + + return cached_data async def cache_oauth_token_async(api_key: str, token_data: dict[str, Any]) -> None: """Cache an OAuth2 token by API key asynchronously""" if not api_key or not token_data: return - await async_oauth_token_cache.set(f"token:{api_key}", token_data) + + cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}" + if "expires_at" not in token_data: + logger.warning("OAuth token cached without expires_at - skipping") + return + + await async_oauth_token_cache.set(cache_key, token_data) async def invalidate_oauth_token_cache_async(api_key: str) -> None: """Invalidate OAuth2 token cache for a specific API key asynchronously""" if not api_key: return - await async_oauth_token_cache.delete(f"token:{api_key}") + + cache_key = f"token:{hashlib.sha256(api_key.encode()).hexdigest()}" + await async_oauth_token_cache.delete(cache_key) if DEBUG_CACHE: - logger.debug(f"Cache: Invalidated OAuth2 token cache for key: {api_key[:8]}...") + logger.debug(f"Cache: Invalidated OAuth2 token cache for key: {cache_key[:16]}...") + + +async def _opportunistic_cleanup(current_time: float, max_items: int = 2) -> None: + """Opportunistically clean up expired OAuth tokens from cache""" + cleaned = 0 + + # Case 1: in-memory backend exposes .cache dict + if hasattr(async_oauth_token_cache, "cache"): + async with async_oauth_token_cache.lock: + for key, value in list(async_oauth_token_cache.cache.items()): + if cleaned >= max_items: + break + if key.startswith("token:"): + expires_at = value.get("expires_at") + if expires_at and expires_at <= current_time: + await async_oauth_token_cache.delete(key) + cleaned += 1 + if DEBUG_CACHE: + logger.debug(f"Cache: Cleaned up expired token: {key[:16]}...") + + # Case 2: Redis backend + elif hasattr(async_oauth_token_cache, "client"): + try: + 
pattern = f"{os.getenv('REDIS_PREFIX', 'forge')}:token:*" + async for redis_key in async_oauth_token_cache.client.scan_iter(match=pattern, count=10): + if cleaned >= max_items: + break + key_str = redis_key.decode() if isinstance(redis_key, bytes) else redis_key + internal_key = key_str.split(":", 1)[-1] + cached_data = await async_oauth_token_cache.get(internal_key) + if cached_data: + expires_at = cached_data.get("expires_at") + if expires_at and expires_at <= current_time: + await async_oauth_token_cache.delete(internal_key) + cleaned += 1 + if DEBUG_CACHE: + logger.debug(f"Cache: Cleaned up expired token: {internal_key[:16]}...") + except Exception as exc: + if DEBUG_CACHE: + logger.warning(f"Failed to perform opportunistic cleanup: {exc}") + + if DEBUG_CACHE and cleaned > 0: + logger.debug(f"Cache: Opportunistic cleanup removed {cleaned} expired tokens") async def invalidate_provider_models_cache_async(provider_name: str) -> None: diff --git a/app/services/providers/vertex_adapter.py b/app/services/providers/vertex_adapter.py index b59ef53..38f40d7 100644 --- a/app/services/providers/vertex_adapter.py +++ b/app/services/providers/vertex_adapter.py @@ -1,5 +1,6 @@ import asyncio import json +import time from collections.abc import AsyncGenerator from datetime import datetime, timezone from typing import Any @@ -130,19 +131,9 @@ async def vertex_authentication(self, api_key: str) -> str: # check cache first for existing valid token cached_token = await get_cached_oauth_token_async(api_key) if cached_token: - token_str = cached_token.get("token") - expiry_str = cached_token.get("expiry") - if token_str and expiry_str: - try: - expiry = datetime.fromisoformat(expiry_str) - # Make expiry timezone-aware if it's naive (Google credentials are UTC) - if expiry.tzinfo is None: - expiry = expiry.replace(tzinfo=timezone.utc) - if expiry > datetime.now(timezone.utc): - return token_str - except (ValueError, TypeError): - # Invalid cached token, clear it and continue to refresh - await invalidate_oauth_token_cache_async(api_key) + access_token = cached_token.get("access_token") + if access_token: + return access_token # load credentials within scope try: @@ -154,8 +145,12 @@ async def vertex_authentication(self, api_key: str) -> str: # cache the token with expiry information if credentials.token and credentials.expiry: token_data = { - "token": credentials.token, - "expiry": credentials.expiry.isoformat() + "access_token": credentials.token, + "token_type": "Bearer", + "expires_at": credentials.expiry.timestamp(), # Unix timestamp + "scope": "https://www.googleapis.com/auth/cloud-platform", + "cached_at": time.time(), # For debugging + "provider": "vertex" # Helpful for multi-provider systems } await cache_oauth_token_async(api_key, token_data)