diff --git a/app/api/schemas/openai.py b/app/api/schemas/openai.py index a9061a9..60d8d20 100644 --- a/app/api/schemas/openai.py +++ b/app/api/schemas/openai.py @@ -2,6 +2,11 @@ from pydantic import BaseModel +from app.core.logger import get_logger +from app.exceptions.exceptions import InvalidCompletionRequestException + +logger = get_logger(name="openai_schemas") + class OpenAIContentImageUrlModel(BaseModel): url: str @@ -21,17 +26,35 @@ class OpenAIContentModel(BaseModel): def __init__(self, **data: Any): super().__init__(**data) if self.type not in ["text", "image_url", "input_audio"]: - raise ValueError( - f"Invalid type: {self.type}. Must be one of: text, image_url, input_audio" + error_message = f"Invalid type: {self.type}. Must be one of: text, image_url, input_audio" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="openai", + error=ValueError(error_message) ) # Validate that the appropriate field is set based on type if self.type == "text" and self.text is None: - raise ValueError("text field must be set when type is 'text'") + error_message = "text field must be set when type is 'text'" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="openai", + error=ValueError(error_message) + ) if self.type == "image_url" and self.image_url is None: - raise ValueError("image_url field must be set when type is 'image_url'") + error_message = "image_url field must be set when type is 'image_url'" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="openai", + error=ValueError(error_message) + ) if self.type == "input_audio" and self.input_audio is None: - raise ValueError("input_audio field must be set when type is 'input_audio'") + error_message = "input_audio field must be set when type is 'input_audio'" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="openai", + error=ValueError(error_message) + ) class ChatMessage(BaseModel): diff --git a/app/exceptions/exceptions.py b/app/exceptions/exceptions.py index 5e6ebac..112527f 100644 --- a/app/exceptions/exceptions.py +++ b/app/exceptions/exceptions.py @@ -1,6 +1,83 @@ -class InvalidProviderException(Exception): +class BaseForgeException(Exception): + pass + +class InvalidProviderException(BaseForgeException): """Exception raised when a provider is invalid.""" def __init__(self, identifier: str): self.identifier = identifier - super().__init__(f"Provider {identifier} is invalid or failed to extract provider info from model_id {identifier}") \ No newline at end of file + super().__init__(f"Provider {identifier} is invalid or failed to extract provider info from model_id {identifier}") + + +class ProviderAuthenticationException(BaseForgeException): + """Exception raised when a provider authentication fails.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} authentication failed: {error}") + + +class BaseInvalidProviderSetupException(BaseForgeException): + """Exception raised when a provider setup is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} setup is invalid: {error}") + +class InvalidProviderConfigException(BaseInvalidProviderSetupException): + """Exception raised when a provider config is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + super().__init__(provider_name, error) + +class InvalidProviderAPIKeyException(BaseInvalidProviderSetupException): + """Exception raised when a provider API key is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + super().__init__(provider_name, error) + +class ProviderAPIException(BaseForgeException): + """Exception raised when a provider API error occurs.""" + + def __init__(self, provider_name: str, error_code: int, error_message: str): + super().__init__(f"Provider {provider_name} API error: {error_code} {error_message}") + + +class BaseInvalidRequestException(BaseForgeException): + """Exception raised when a request is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} request is invalid: {error}") + +class InvalidCompletionRequestException(BaseInvalidRequestException): + """Exception raised when a completion request is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} completion request is invalid: {error}") + +class InvalidEmbeddingsRequestException(BaseInvalidRequestException): + """Exception raised when a embeddings request is invalid.""" + + def __init__(self, provider_name: str, error: Exception): + self.provider_name = provider_name + self.error = error + super().__init__(f"Provider {provider_name} embeddings request is invalid: {error}") + +class BaseInvalidForgeKeyException(BaseForgeException): + """Exception raised when a Forge key is invalid.""" + + def __init__(self, error: Exception): + self.error = error + super().__init__(f"Forge key is invalid: {error}") + + +class InvalidForgeKeyException(BaseInvalidForgeKeyException): + """Exception raised when a Forge key is invalid.""" + def __init__(self, error: Exception): + super().__init__(error) \ No newline at end of file diff --git a/app/main.py b/app/main.py index 69b6047..85364a6 100644 --- a/app/main.py +++ b/app/main.py @@ -3,7 +3,7 @@ from collections.abc import Callable from dotenv import load_dotenv -from fastapi import APIRouter, FastAPI, Request +from fastapi import APIRouter, FastAPI, Request, HTTPException from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.base import BaseHTTPMiddleware @@ -20,6 +20,8 @@ from app.core.database import engine from app.core.logger import get_logger from app.models.base import Base +from app.exceptions.exceptions import ProviderAuthenticationException, InvalidProviderException, BaseInvalidProviderSetupException, \ + ProviderAPIException, BaseInvalidRequestException, BaseInvalidForgeKeyException load_dotenv() @@ -77,6 +79,65 @@ async def dispatch(self, request: Request, call_next: Callable): openapi_url="/openapi.json" if not is_production else None, ) +### Exception handlers block ### + +# Add exception handler for ProviderAuthenticationException +@app.exception_handler(ProviderAuthenticationException) +async def provider_authentication_exception_handler(request: Request, exc: ProviderAuthenticationException): + return HTTPException( + status_code=401, + detail=f"Authentication failed for provider {exc.provider_name}" + ) + +# Add exception handler for InvalidProviderException +@app.exception_handler(InvalidProviderException) +async def invalid_provider_exception_handler(request: Request, exc: InvalidProviderException): + return HTTPException( + status_code=400, + detail=f"{str(exc)}. Please verify your provider and model details by calling the /models endpoint or visiting https://tensorblock.co/api-docs/model-ids, and ensure you’re using a valid provider name, model name, and model ID." + ) + +# Add exception handler for BaseInvalidProviderSetupException +@app.exception_handler(BaseInvalidProviderSetupException) +async def base_invalid_provider_setup_exception_handler(request: Request, exc: BaseInvalidProviderSetupException): + return HTTPException( + status_code=400, + detail=str(exc) + ) + +# Add exception handler for ProviderAPIException +@app.exception_handler(ProviderAPIException) +async def provider_api_exception_handler(request: Request, exc: ProviderAPIException): + return HTTPException( + status_code=exc.error_code, + detail=f"Provider API error: {exc.provider_name} {exc.error_code} {exc.error_message}" + ) + +# Add exception handler for BaseInvalidRequestException +@app.exception_handler(BaseInvalidRequestException) +async def base_invalid_request_exception_handler(request: Request, exc: BaseInvalidRequestException): + return HTTPException( + status_code=400, + detail=str(exc) + ) + +# Add exception handler for BaseInvalidForgeKeyException +@app.exception_handler(BaseInvalidForgeKeyException) +async def base_invalid_forge_key_exception_handler(request: Request, exc: BaseInvalidForgeKeyException): + return HTTPException( + status_code=401, + detail=f"Invalid Forge key: {exc.error}" + ) + +# Add exception handler for NotImplementedError +@app.exception_handler(NotImplementedError) +async def not_implemented_error_handler(request: Request, exc: NotImplementedError): + return HTTPException( + status_code=404, + detail=f"Not implemented: {exc}" + ) +### Exception handlers block ends ### + # Middleware to log slow requests @app.middleware("http") diff --git a/app/services/provider_service.py b/app/services/provider_service.py index 6793283..61688e3 100644 --- a/app/services/provider_service.py +++ b/app/services/provider_service.py @@ -15,7 +15,7 @@ ) from app.core.logger import get_logger from app.core.security import decrypt_api_key, encrypt_api_key -from app.exceptions.exceptions import InvalidProviderException +from app.exceptions.exceptions import InvalidProviderException, BaseInvalidRequestException, InvalidForgeKeyException from app.models.user import User from app.services.usage_stats_service import UsageStatsService @@ -322,6 +322,7 @@ def _get_provider_info_with_prefix( ) if not matching_provider: + logger.error(f"No matching provider found for {original_model}") raise InvalidProviderException(original_model) provider_data = self.provider_keys[matching_provider] @@ -358,6 +359,7 @@ def _find_provider_for_unprefixed_model( provider_data.get("base_url"), ) + logger.error(f"No matching provider found for {model}") raise InvalidProviderException(model) def _get_provider_info(self, model: str) -> tuple[str, str, str | None]: @@ -365,9 +367,9 @@ def _get_provider_info(self, model: str) -> tuple[str, str, str | None]: Determine the provider based on the model name. """ if not self._keys_loaded: - raise RuntimeError( - "Provider keys must be loaded before calling _get_provider_info. Call _load_provider_keys_async() first." - ) + error_message = "Provider keys must be loaded before calling _get_provider_info. Call _load_provider_keys_async() first." + logger.error(error_message) + raise RuntimeError(error_message) provider_name, model_name_no_prefix = self._extract_provider_name_prefix(model) @@ -485,25 +487,23 @@ async def process_request( model = payload.get("model") if not model: - raise ValueError("Model is required") - - try: - provider_name, actual_model, base_url = self._get_provider_info(model) - - # Enforce scope restriction (case-insensitive). - if allowed_provider_names is not None and ( - provider_name.lower() not in {p.lower() for p in allowed_provider_names} - ): - raise ValueError( - f"API key is not permitted to use provider '{provider_name}'." - ) - except ValueError as e: - # Use parameterized logging to avoid issues if the error message contains braces - logger.error("Error getting provider info for model {}: {}", model, str(e)) - raise ValueError( - f"Invalid model ID: {model}. Please check your model configuration." + error_message = "Model is required" + logger.error(error_message) + raise BaseInvalidRequestException( + provider_name="unknown", + error=ValueError(error_message) ) + provider_name, actual_model, base_url = self._get_provider_info(model) + + # Enforce scope restriction (case-insensitive). + if allowed_provider_names is not None and ( + provider_name.lower() not in {p.lower() for p in allowed_provider_names} + ): + error_message = f"API key is not permitted to use provider '{provider_name}'." + logger.error(error_message) + raise InvalidForgeKeyException(error=ValueError(error_message)) + logger.debug( f"Processing request for provider: {provider_name}, model: {actual_model}, endpoint: {endpoint}" ) @@ -513,7 +513,9 @@ async def process_request( # Get the provider's API key if provider_name not in self.provider_keys: - raise ValueError(f"No API key found for provider {provider_name}") + error_message = f"API key is not permitted to use provider '{provider_name}'." + logger.error(error_message) + raise InvalidForgeKeyException(error=ValueError(error_message)) serialized_api_key_config = self.provider_keys[provider_name]["api_key"] provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls(provider_name) @@ -534,9 +536,9 @@ async def process_request( elif "images/generations" in endpoint: # TODO: we only support openai for now if provider_name != "openai": - raise ValueError( - f"Unsupported endpoint: {endpoint} for provider {provider_name}" - ) + error_message = f"Unsupported endpoint: {endpoint} for provider {provider_name}" + logger.error(error_message) + raise NotImplementedError(error_message) result = await adapter.process_image_generation( endpoint, payload, @@ -545,9 +547,9 @@ async def process_request( elif "images/edits" in endpoint: # TODO: we only support openai for now if provider_name != "openai": - raise NotImplementedError( - f"Unsupported endpoint: {endpoint} for provider {provider_name}" - ) + error_message = f"Unsupported endpoint: {endpoint} for provider {provider_name}" + logger.error(error_message) + raise NotImplementedError(error_message) result = await adapter.process_image_edits( endpoint, payload, @@ -560,7 +562,9 @@ async def process_request( api_key, ) else: - raise ValueError(f"Unsupported endpoint: {endpoint}") + error_message = f"Unsupported endpoint: {endpoint}" + logger.error(error_message) + raise NotImplementedError(error_message) # Track usage statistics if it's not a streaming response if not inspect.isasyncgen(result): diff --git a/app/services/providers/anthropic_adapter.py b/app/services/providers/anthropic_adapter.py index 9cb2c7e..4533d31 100644 --- a/app/services/providers/anthropic_adapter.py +++ b/app/services/providers/anthropic_adapter.py @@ -7,8 +7,13 @@ import aiohttp +from app.core.logger import get_logger +from app.exceptions.exceptions import ProviderAPIException, InvalidCompletionRequestException + from .base import ProviderAdapter +logger = get_logger(name="anthropic_adapter") + ANTHROPIC_DEFAULT_MAX_TOKENS = 4096 @@ -58,21 +63,23 @@ def convert_openai_content_to_anthropic( if isinstance(content, str): return content - try: - result = [] - for msg in content: - _type = msg["type"] - if _type == "text": - result.append({"type": "text", "text": msg["text"]}) - elif _type == "image_url": - result.append( - AnthropicAdapter.convert_openai_image_content_to_anthropic(msg) - ) - else: - raise NotImplementedError(f"{_type} is not supported") - return result - except Exception as e: - raise NotImplementedError("Unsupported content type") from e + result = [] + for msg in content: + _type = msg["type"] + if _type == "text": + result.append({"type": "text", "text": msg["text"]}) + elif _type == "image_url": + result.append( + AnthropicAdapter.convert_openai_image_content_to_anthropic(msg) + ) + else: + error_message = f"{_type} is not supported" + logger.error(error_message) + raise InvalidCompletionRequestException( + provider_name="anthropic", + error=ValueError(error_message) + ) + return result async def list_models(self, api_key: str) -> list[str]: """List all models (verbosely) supported by the provider""" @@ -95,7 +102,12 @@ async def list_models(self, api_key: str) -> list[str]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Anthropic API error: {error_text}") + logger.error(f"List Models API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) resp = await response.json() self.CLAUDE_MODEL_MAPPING = { d["display_name"]: d["id"] for d in resp["data"] @@ -194,8 +206,11 @@ async def stream_response() -> AsyncGenerator[bytes, None]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError( - f"Anthropic API error: {response.status} - {error_text}" + logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text ) buffer = "" @@ -212,13 +227,11 @@ async def stream_response() -> AsyncGenerator[bytes, None]: elif line.startswith("data:"): data_str = line[len("data:") :].strip() - # logger.info(f"Anthropic Raw Event - Type: {event_type}, Data String: {data_str}") if not event_type or data_str is None: continue try: data = json.loads(data_str) - # logger.info(f"Anthropic Parsed Data: {data}") openai_chunk = None finish_reason = None # --- Event Processing Logic --- @@ -230,7 +243,6 @@ async def stream_response() -> AsyncGenerator[bytes, None]: captured_input_tokens = message_data["usage"].get( "input_tokens", 0 ) - # logger.info(f"Captured input_tokens: {captured_input_tokens}") captured_output_tokens = message_data["usage"].get( "output_tokens", captured_output_tokens ) @@ -267,7 +279,6 @@ async def stream_response() -> AsyncGenerator[bytes, None]: captured_output_tokens = usage_data_in_delta.get( "output_tokens", captured_output_tokens ) - # logger.info(f"Captured output_tokens from top-level delta usage: {captured_output_tokens}") if captured_input_tokens > 0: usage_info_complete = True @@ -282,7 +293,6 @@ async def stream_response() -> AsyncGenerator[bytes, None]: captured_output_tokens = usage.get( "output_tokens", captured_output_tokens ) - # logger.info(f"Captured usage from stop: in={captured_input_tokens}, out={captured_output_tokens}") if ( captured_input_tokens > 0 and captured_output_tokens > 0 @@ -315,13 +325,11 @@ async def stream_response() -> AsyncGenerator[bytes, None]: yield f"data: {json.dumps(usage_chunk)}\n\n".encode() # Reset flag to prevent duplicate yields usage_info_complete = False - # logger.info(f"Yielded usage chunk: {usage_chunk}") - except json.JSONDecodeError: - # logger.warning(f"Anthropic stream: Failed to parse JSON: {data_str}") + except json.JSONDecodeError as e: + logger.warning(f"Stream API error for {self.provider_name}: Failed to parse JSON: {e}") continue - except Exception: - # logger.error(f"Anthropic stream processing error: {e}", exc_info=True) + except Exception as e: continue # Final SSE message @@ -340,7 +348,12 @@ async def _process_regular_response( ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Anthropic API error: {error_text}") + logger.error(f"Completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) anthropic_response = await response.json() diff --git a/app/services/providers/azure_adapter.py b/app/services/providers/azure_adapter.py index 94f7506..04936b3 100644 --- a/app/services/providers/azure_adapter.py +++ b/app/services/providers/azure_adapter.py @@ -3,9 +3,10 @@ from .openai_adapter import OpenAIAdapter from app.core.logger import get_logger +from app.exceptions.exceptions import ProviderAPIException, BaseInvalidProviderSetupException # Configure logging -logger = get_logger(name="openai_adapter") +logger = get_logger(name="azure_adapter") class AzureAdapter(OpenAIAdapter): def __init__(self, provider_name: str, base_url: str, config: dict[str, Any]): @@ -18,15 +19,27 @@ def assert_valid_base_url(base_url: str) -> str: # base_url is required # e.g: https://THE-RESOURCE-NAME.openai.azure.com/ if not base_url: - raise ValueError("Azure base URL is required") + error_text = "Azure base URL is required" + logger.error(error_text) + raise BaseInvalidProviderSetupException( + provider_name="azure", + error=ValueError(error_text) + ) base_url = base_url.rstrip("/") return base_url @staticmethod def serialize_api_key_config(api_key: str, config: dict[str, Any] | None) -> str: """Serialize the API key for the given provider""" - assert config is not None - assert config.get("api_version") is not None + try: + assert config is not None + assert config.get("api_version") is not None + except AssertionError as e: + logger.error(str(e)) + raise BaseInvalidProviderSetupException( + provider_name="azure", + error=e + ) return json.dumps({ "api_key": api_key, @@ -36,9 +49,16 @@ def serialize_api_key_config(api_key: str, config: dict[str, Any] | None) -> str @staticmethod def deserialize_api_key_config(serialized_api_key_config: str) -> tuple[str, dict[str, Any] | None]: """Deserialize the API key for the given provider""" - deserialized_api_key_config = json.loads(serialized_api_key_config) - assert deserialized_api_key_config.get("api_key") is not None - assert deserialized_api_key_config.get("api_version") is not None + try: + deserialized_api_key_config = json.loads(serialized_api_key_config) + assert deserialized_api_key_config.get("api_key") is not None + assert deserialized_api_key_config.get("api_version") is not None + except Exception as e: + logger.error(str(e)) + raise BaseInvalidProviderSetupException( + provider_name="azure", + error=e + ) return deserialized_api_key_config["api_key"], { "api_version": deserialized_api_key_config["api_version"], diff --git a/app/services/providers/bedrock_adapter.py b/app/services/providers/bedrock_adapter.py index e38edf7..63a3f5d 100644 --- a/app/services/providers/bedrock_adapter.py +++ b/app/services/providers/bedrock_adapter.py @@ -8,6 +8,7 @@ from typing import Any from app.core.logger import get_logger +from app.exceptions.exceptions import BaseInvalidProviderSetupException, ProviderAPIException, InvalidCompletionRequestException, BaseForgeException from .base import ProviderAdapter @@ -27,13 +28,32 @@ class BedrockAdapter(ProviderAdapter): def __init__(self, provider_name: str, base_url: str, config: dict[str, str] | None = None): self._provider_name = provider_name self._base_url = base_url - assert "region_name" in config, "Bedrock region_name is required" + self.parse_config(config) + self._session = aiobotocore.session.get_session() + + @staticmethod + def validate_config(config: dict[str, str] | None): + """Validate the config for the given provider""" + + try: + assert config is not None, "Bedrock config is required" + assert "region_name" in config, "Bedrock region_name is required" + assert "aws_access_key_id" in config, "Bedrock aws_access_key_id is required" + assert "aws_secret_access_key" in config, "Bedrock aws_secret_access_key is required" + except AssertionError as e: + logger.error(str(e)) + raise BaseInvalidProviderSetupException( + provider_name="bedrock", + error=e + ) + + def parse_config(self, config: dict[str, str] | None) -> None: + """Parse the config for the given provider""" + + self.validate_config(config) self._region_name = config["region_name"] - assert "aws_access_key_id" in config, "Bedrock aws_access_key_id is required" self._aws_access_key_id = config["aws_access_key_id"] - assert "aws_secret_access_key" in config, "Bedrock aws_secret_access_key is required" self._aws_secret_access_key = config["aws_secret_access_key"] - self._session = aiobotocore.session.get_session() @property def client_ctx(self): @@ -50,10 +70,7 @@ def provider_name(self) -> str: @staticmethod def serialize_api_key_config(api_key: str, config: dict[str, Any] | None) -> str: """Serialize the API key for the given provider""" - assert config is not None - assert config.get("region_name") is not None - assert config.get("aws_access_key_id") is not None - assert config.get("aws_secret_access_key") is not None + BedrockAdapter.validate_config(config) return json.dumps({ "api_key": api_key, @@ -65,11 +82,18 @@ def serialize_api_key_config(api_key: str, config: dict[str, Any] | None) -> str @staticmethod def deserialize_api_key_config(serialized_api_key_config: str) -> tuple[str, dict[str, Any] | None]: """Deserialize the API key for the given provider""" - deserialized_api_key_config = json.loads(serialized_api_key_config) - assert deserialized_api_key_config.get("api_key") is not None - assert deserialized_api_key_config.get("region_name") is not None - assert deserialized_api_key_config.get("aws_access_key_id") is not None - assert deserialized_api_key_config.get("aws_secret_access_key") is not None + try: + deserialized_api_key_config = json.loads(serialized_api_key_config) + assert deserialized_api_key_config.get("api_key") is not None + assert deserialized_api_key_config.get("region_name") is not None + assert deserialized_api_key_config.get("aws_access_key_id") is not None + assert deserialized_api_key_config.get("aws_secret_access_key") is not None + except Exception as e: + logger.error(str(e)) + raise BaseInvalidProviderSetupException( + provider_name="bedrock", + error=e + ) return deserialized_api_key_config["api_key"], { "region_name": deserialized_api_key_config["region_name"], @@ -80,10 +104,7 @@ def deserialize_api_key_config(serialized_api_key_config: str) -> tuple[str, dic @staticmethod def mask_config(config: dict[str, Any] | None) -> dict[str, Any] | None: """Mask the config for the given provider""" - assert config is not None - assert config.get("region_name") is not None - assert config.get("aws_access_key_id") is not None - assert config.get("aws_secret_access_key") is not None + BedrockAdapter.validate_config(config) mask_str = "*" * 7 return { "region_name": config["region_name"][:3] + mask_str + config["region_name"][-3:], @@ -105,7 +126,13 @@ async def list_models(self, api_key: str) -> list[str]: try: response = await bedrock.list_foundation_models() except Exception as e: - raise ValueError(f"Bedrock API error: {e}") + error_text = f"List models API error for {self.provider_name}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=400, + error_message=error_text + ) models = [r["modelId"] for r in response["modelSummaries"]] @@ -152,9 +179,22 @@ async def convert_openai_image_content_to_bedrock(msg: list[dict[str, Any]] | st }, } } + except aiohttp.ClientResponseError as e: + error_text = f"Bedrock API error: failed to download image from {data_url}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name="bedrock", + error_code=e.status, + error_message=error_text + ) except Exception as e: - logger.warning(f"Bedrock API error: failed to download image from {data_url}: {e}") - raise Exception(f"Bedrock API error: {e}") + error_text = f"Bedrock API error: failed to download image from {data_url}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name="bedrock", + error_code=500, + error_message=error_text + ) @staticmethod async def convert_openai_content_to_bedrock(content: list[dict[str, Any]] | str) -> list[dict[str, Any]]: @@ -171,10 +211,22 @@ async def convert_openai_content_to_bedrock(content: list[dict[str, Any]] | str) elif _type == "image_url": result.append(await BedrockAdapter.convert_openai_image_content_to_bedrock(msg)) else: - raise NotImplementedError(f"{_type} is not supported") + error_text = f"Bedrock API request error: {_type} is not supported" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name="bedrock", + error=ValueError(error_text) + ) return result + except BaseForgeException as e: + raise e except Exception as e: - raise NotImplementedError("Unsupported content type") from e + error_text = f"Bedrock API request error: {e}" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name="bedrock", + error=e + ) from e @staticmethod async def convert_openai_payload_to_bedrock(payload: dict[str, Any]) -> dict[str, Any]: @@ -224,7 +276,13 @@ async def _process_regular_response(self, bedrock_payload: dict[str, Any]) -> di **bedrock_payload, ) except Exception as e: - raise ValueError(f"Bedrock API error: {e}") + error_text = f"Completion API error for {self.provider_name}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=400, + error_message=error_text + ) if message := response.get("output", {}).get("message"): completion_id = f"chatcmpl-{str(uuid.uuid4())}" @@ -239,7 +297,13 @@ async def _process_regular_response(self, bedrock_payload: dict[str, Any]) -> di if _type == "text": text_content += value else: - raise NotImplementedError(f"Bedrock API error: {_type} response is not supported") + error_text = f"Completion API error for {self.provider_name}: {_type} response is not supported" + logger.error(error_text) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=400, + error_message=error_text + ) input_tokens = response.get("usage", {}).get("inputTokens", 0) output_tokens = response.get("usage", {}).get("outputTokens", 0) @@ -365,7 +429,13 @@ async def _process_streaming_response(self, bedrock_payload: dict[str, Any]) -> if openai_chunk: yield f"data: {json.dumps(openai_chunk)}\n\n".encode() except Exception as e: - raise ValueError(f"Bedrock API error: {e}") + error_text = f"Streaming completion API error for {self.provider_name}: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=400, + error_message=error_text + ) if final_usage_data: openai_chunk = { "id": request_id, @@ -382,8 +452,10 @@ async def _process_streaming_response(self, bedrock_payload: dict[str, Any]) -> # Send final [DONE] message yield b"data: [DONE]\n\n" + except BaseForgeException as e: + raise e except Exception as e: - logger.error(f"Bedrock streaming API error: {str(e)}", exc_info=True) + logger.error(f"Streaming completion API error for {self.provider_name}: {e}", exc_info=True) error_chunk = { "id": str(uuid.uuid4()), "object": "chat.completion.chunk", diff --git a/app/services/providers/cohere_adapter.py b/app/services/providers/cohere_adapter.py index 602ce8d..27e120e 100644 --- a/app/services/providers/cohere_adapter.py +++ b/app/services/providers/cohere_adapter.py @@ -8,6 +8,7 @@ import aiohttp from app.core.logger import get_logger +from app.exceptions.exceptions import ProviderAPIException, BaseForgeException from .base import ProviderAdapter @@ -45,7 +46,12 @@ async def list_models(self, api_key: str) -> list[str]: ) as response: if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Cohere API error: {error_text}") + logger.error(f"List models API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) resp = await response.json() models = [d["name"] for d in resp["models"]] @@ -129,8 +135,12 @@ async def _stream_cohere_response( ): if response.status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Cohere API error: {response.status} - {error_text}") - raise ValueError(f"Cohere stream API error: {error_text}") + logger.error(f"Streaming completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) # Track the message ID for consistency message_id = None @@ -282,7 +292,7 @@ async def _stream_cohere_response( # # Send final [DONE] message yield b"data: [DONE]\n\n" except Exception as e: - logger.error(f"Cohere streaming API error: {str(e)}", exc_info=True) + logger.error(f"Streaming completion API error for {self.provider_name}: {e}", exc_info=True) error_chunk = { "id": uuid.uuid4(), "object": "chat.completion.chunk", @@ -307,7 +317,12 @@ async def _process_cohere_chat_completion( async with session.post(url, headers=headers, json=payload) as response: if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Cohere API error: {error_text}") + logger.error(f"Completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) resp = await response.json() return self._convert_cohere_to_openai(resp, payload["model"]) @@ -389,10 +404,22 @@ async def process_embeddings( async with session.post(url, headers=headers, json=cohere_payload) as response: if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Cohere API error: {error_text}") + logger.error(f"Embeddings API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) response_json = await response.json() return self.convert_cohere_embeddings_response_to_openai(response_json, payload["model"]) + except BaseForgeException as e: + raise e except Exception as e: - logger.error(f"Error in Cohere embeddings: {str(e)}", exc_info=True) - raise + error_text = f"Embeddings API error for {self.provider_name}: {e}" + logger.error(error_text, exc_info=True) + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=500, + error_message=error_text + ) diff --git a/app/services/providers/google_adapter.py b/app/services/providers/google_adapter.py index 74228ed..8a44820 100644 --- a/app/services/providers/google_adapter.py +++ b/app/services/providers/google_adapter.py @@ -10,6 +10,8 @@ import aiohttp from app.core.logger import get_logger +from app.exceptions.exceptions import BaseForgeException, BaseInvalidRequestException, ProviderAPIException, InvalidCompletionRequestException, \ + InvalidEmbeddingsRequestException from .base import ProviderAdapter @@ -55,7 +57,12 @@ async def list_models(self, api_key: str) -> list[str]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"Google API error: {error_text}") + logger.error(f"List Models API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) resp = await response.json() self.GOOGLE_MODEL_MAPPING = { d["displayName"]: d["name"] for d in resp["models"] @@ -91,8 +98,12 @@ async def upload_file_to_gemini( # First, get the file metadata from the URL async with session.head(file_url) as response: if response.status != HTTPStatus.OK: - raise Exception( - f"Gemini Upload API error: Failed to fetch file metadata from URL: {response.status}" + error_text = await response.text() + logger.error(f"Gemini Upload API error: Failed to fetch file metadata from URL: {error_text}") + raise ProviderAPIException( + provider_name="google", + error_code=response.status, + error_message=error_text ) mime_type = response.headers.get("Content-Type", "application/octet-stream") @@ -116,13 +127,23 @@ async def upload_file_to_gemini( f"{base_url}?key={api_key}", headers=headers, json=metadata ) as response: if response.status != HTTPStatus.OK: - raise Exception( - f"Gemini Upload API error: Failed to initiate upload: {response.status}" + error_text = await response.text() + logger.error(f"Gemini Upload API error: Failed to initiate upload: {error_text}") + raise ProviderAPIException( + provider_name="google", + error_code=response.status, + error_message=error_text ) upload_url = response.headers.get("X-Goog-Upload-URL") if not upload_url: - raise Exception("No upload URL received from server") + error_text = "Gemini Upload API error: No upload URL received from server" + logger.error(error_text) + raise ProviderAPIException( + provider_name="google", + error_code=404, + error_message=error_text + ) # Upload the file content using streaming upload_headers = { @@ -134,16 +155,24 @@ async def upload_file_to_gemini( # Stream the file content directly from the source URL to Gemini API async with session.get(file_url) as source_response: if source_response.status != HTTPStatus.OK: - raise Exception( - f"Gemini Upload API error: Failed to fetch file content: {source_response.status}" + error_text = await source_response.text() + logger.error(f"Gemini Upload API error: Failed to fetch file content: {error_text}") + raise ProviderAPIException( + provider_name="google", + error_code=source_response.status, + error_message=error_text ) async with session.put( upload_url, headers=upload_headers, data=source_response.content ) as upload_response: if upload_response.status != HTTPStatus.OK: - raise Exception( - f"Gemini Upload API error: Failed to upload file: {upload_response.status}" + error_text = await upload_response.text() + logger.error(f"Gemini Upload API error: Failed to upload file: {error_text}") + raise ProviderAPIException( + provider_name="google", + error_code=upload_response.status, + error_message=error_text ) return await upload_response.json() @@ -179,9 +208,16 @@ async def convert_openai_image_content_to_google( "file_uri": result["file"]["uri"], } } - except Exception as e: - logger.error(f"Error uploading image to Google Gemini: {e}") + except ProviderAPIException as e: raise e + except Exception as e: + error_text = f"Error uploading image to Google Gemini: {e}" + logger.error(error_text) + raise ProviderAPIException( + provider_name="google", + error_code=400, + error_message=error_text + ) @staticmethod async def convert_openai_content_to_google( @@ -205,10 +241,21 @@ async def convert_openai_content_to_google( ) ) else: - raise NotImplementedError(f"{_type} is not supported") + error_text = f"{_type} is not supported" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name="google", + error=ValueError(error_text) + ) return result + except BaseForgeException as e: + raise e except Exception as e: - raise NotImplementedError("Unsupported content type") from e + logger.error(f"Error converting OpenAI content to Google: {e}") + raise BaseInvalidRequestException( + provider_name="google", + error=e + ) async def process_completion( self, @@ -256,10 +303,19 @@ async def _stream_google_response( try: if not google_payload: - raise ValueError("Empty payload for Google API request") + error_text = f"Empty payload for {self.provider_name} API request" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name=self.provider_name, + error=ValueError(error_text) + ) if not api_key: - raise ValueError("Missing API key for Google request") - + error_text = f"Missing API key for {self.provider_name} API request" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name=self.provider_name, + error=ValueError(error_text) + ) headers = {"Content-Type": "application/json", "Accept": "application/json"} logger.debug( f"Google API request - URL: {url}, Payload sample: {str(google_payload)[:200]}..." @@ -273,8 +329,12 @@ async def _stream_google_response( ): if response.status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Google API error: {response.status} - {error_text}") - raise ValueError(f"Google stream API error: {error_text}") + logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) # Process response in chunks buffer = "" @@ -432,7 +492,6 @@ async def _process_google_chat_completion( ) -> dict[str, Any]: """Process a regular (non-streaming) chat completion with Google Gemini""" model = payload.get("model", "") - logger.info(f"Processing regular chat completion for model: {model}") # Convert payload to Google format google_payload = await self.convert_openai_completion_payload_to_google(payload, api_key) @@ -441,7 +500,6 @@ async def _process_google_chat_completion( model_path = model if model.startswith("models/") else f"models/{model}" url = f"{self._base_url}/{model_path}:generateContent" - logger.info(f"Sending request to Google API: {url}") try: # Make the API request @@ -449,10 +507,12 @@ async def _process_google_chat_completion( # Check for API key if not api_key: - raise ValueError("Missing API key for Google request") - - logger.debug(f"Google API request - Headers: {headers}") - logger.debug(f"Google API request - URL: {url}") + error_text = f"Missing API key for {self.provider_name} API request" + logger.error(error_text) + raise InvalidCompletionRequestException( + provider_name=self.provider_name, + error=ValueError(error_text) + ) async with ( aiohttp.ClientSession() as session, @@ -461,25 +521,27 @@ async def _process_google_chat_completion( ) as response, ): response_status = response.status - logger.info(f"Google API response status: {response_status}") - if response_status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Google API error: {response_status} - {error_text}") - - raise ValueError(f"Google API error: {error_text}") + logger.error(f"Completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response_status, + error_message=error_text + ) response_json = await response.json() - logger.debug( - f"Google API response: {json.dumps(response_json)[:200]}..." - ) # Convert to OpenAI format return self.convert_google_completion_response_to_openai(response_json, model) - + except BaseForgeException as e: + raise e except Exception as e: logger.error(f"Error in Google chat completion: {str(e)}", exc_info=True) - raise + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=e + ) @staticmethod def convert_google_completion_response_to_openai( @@ -600,7 +662,6 @@ async def process_embeddings( model_path = model if model.startswith("models/") else f"models/{model}" url = f"{self._base_url}/{model_path}:embedContent" - logger.info(f"Sending request to Google API: {url}") # Convert payload to Google format google_payload = self.convert_openai_embeddings_payload_to_google(payload) @@ -610,7 +671,12 @@ async def process_embeddings( # Check for API key if not api_key: - raise ValueError("Missing API key for Google request") + error_text = f"Missing API key for {self.provider_name} API request" + logger.error(error_text) + raise InvalidEmbeddingsRequestException( + provider_name=self.provider_name, + error=ValueError(error_text) + ) async with ( aiohttp.ClientSession() as session, @@ -619,17 +685,22 @@ async def process_embeddings( ) as response, ): response_status = response.status - logger.info(f"Google API response status: {response_status}") - if response_status != HTTPStatus.OK: error_text = await response.text() - logger.error(f"Google API error: {response_status} - {error_text}") - - raise ValueError(f"Google API error: {error_text}") + logger.error(f"Embeddings API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response_status, + error_message=error_text + ) response_json = await response.json() return self.convert_google_embeddings_response_to_openai(response_json, model) - + except BaseForgeException as e: + raise e except Exception as e: - logger.error(f"Error in Google embeddings: {str(e)}", exc_info=True) - raise + logger.error(f"Error in {self.provider_name} embeddings: {str(e)}", exc_info=True) + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=e + ) diff --git a/app/services/providers/openai_adapter.py b/app/services/providers/openai_adapter.py index fae3bf3..933778b 100644 --- a/app/services/providers/openai_adapter.py +++ b/app/services/providers/openai_adapter.py @@ -4,6 +4,7 @@ import aiohttp from app.core.logger import get_logger +from app.exceptions.exceptions import ProviderAPIException, BaseInvalidRequestException from .base import ProviderAdapter @@ -27,8 +28,7 @@ def __init__( def provider_name(self) -> str: return self._provider_name - @staticmethod - def get_model_id(payload: dict[str, Any]) -> str: + def get_model_id(self, payload: dict[str, Any]) -> str: """Get the model ID from the payload""" if "id" in payload: return payload["id"] @@ -37,7 +37,11 @@ def get_model_id(payload: dict[str, Any]) -> str: elif "model_id" in payload: return payload["model_id"] else: - raise ValueError("Model ID not found in payload") + logger.error(f"Model ID not found in payload for {self.provider_name}") + raise BaseInvalidRequestException( + provider_name=self.provider_name, + error=ValueError("Model ID not found in payload") + ) async def list_models( self, @@ -66,7 +70,12 @@ async def list_models( ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + logger.error(f"List Models API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) resp = await response.json() # Better compatibility with Forge @@ -113,15 +122,17 @@ async def stream_response() -> AsyncGenerator[bytes, None]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError( - f"{self.provider_name} API error: {error_text}" + logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text ) # Stream the response back async for chunk in response.content: if self.provider_name == "azure": chunk = self.process_streaming_chunk(chunk) - logger.info(f"OpenAI streaming chunk: {chunk}") if chunk: yield chunk @@ -137,7 +148,12 @@ async def stream_response() -> AsyncGenerator[bytes, None]: ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + logger.error(f"Completion API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) return await response.json() @@ -161,7 +177,12 @@ async def process_image_generation( ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + logger.error(f"Image Generation API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) return await response.json() @@ -185,7 +206,12 @@ async def process_image_edits( ): if response.status != HTTPStatus.OK: error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + logger.error(f"API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) return await response.json() @@ -211,16 +237,17 @@ async def process_embeddings( url = f"{base_url or self._base_url}/{endpoint}" query_params = query_params or {} - try: - async with ( - aiohttp.ClientSession() as session, - session.post(url, headers=headers, json=payload, params=query_params) as response, - ): - if response.status != HTTPStatus.OK: - error_text = await response.text() - raise ValueError(f"{self.provider_name} API error: {error_text}") + async with ( + aiohttp.ClientSession() as session, + session.post(url, headers=headers, json=payload, params=query_params) as response, + ): + if response.status != HTTPStatus.OK: + error_text = await response.text() + logger.error(f"Embeddings API error for {self.provider_name}: {error_text}") + raise ProviderAPIException( + provider_name=self.provider_name, + error_code=response.status, + error_message=error_text + ) - return await response.json() - except Exception as e: - logger.error(f"Error in OpenAI embeddings: {str(e)}", exc_info=True) - raise + return await response.json() diff --git a/tests/unit_tests/test_provider_service_images.py b/tests/unit_tests/test_provider_service_images.py index 936e8fd..1451585 100644 --- a/tests/unit_tests/test_provider_service_images.py +++ b/tests/unit_tests/test_provider_service_images.py @@ -117,7 +117,7 @@ async def test_process_request_images_generations_routing( mock_anthropic_adapter.process_image_generation.side_effect = ValueError( "Unsupported endpoint: images/generations for provider anthropic" ) - with self.assertRaises(ValueError) as context: + with self.assertRaises(NotImplementedError) as context: await service.process_request( "images/generations", {"model": "claude-3-opus"} )