Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions app/api/schemas/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

from pydantic import BaseModel

from app.core.logger import get_logger
from app.exceptions.exceptions import InvalidCompletionRequestException

logger = get_logger(name="openai_schemas")


class OpenAIContentImageUrlModel(BaseModel):
url: str
Expand All @@ -21,17 +26,35 @@ class OpenAIContentModel(BaseModel):
def __init__(self, **data: Any):
super().__init__(**data)
if self.type not in ["text", "image_url", "input_audio"]:
raise ValueError(
f"Invalid type: {self.type}. Must be one of: text, image_url, input_audio"
error_message = f"Invalid type: {self.type}. Must be one of: text, image_url, input_audio"
logger.error(error_message)
raise InvalidCompletionRequestException(
provider_name="openai",
error=ValueError(error_message)
)

# Validate that the appropriate field is set based on type
if self.type == "text" and self.text is None:
raise ValueError("text field must be set when type is 'text'")
error_message = "text field must be set when type is 'text'"
logger.error(error_message)
raise InvalidCompletionRequestException(
provider_name="openai",
error=ValueError(error_message)
)
if self.type == "image_url" and self.image_url is None:
raise ValueError("image_url field must be set when type is 'image_url'")
error_message = "image_url field must be set when type is 'image_url'"
logger.error(error_message)
raise InvalidCompletionRequestException(
provider_name="openai",
error=ValueError(error_message)
)
if self.type == "input_audio" and self.input_audio is None:
raise ValueError("input_audio field must be set when type is 'input_audio'")
error_message = "input_audio field must be set when type is 'input_audio'"
logger.error(error_message)
raise InvalidCompletionRequestException(
provider_name="openai",
error=ValueError(error_message)
)


class ChatMessage(BaseModel):
Expand Down
81 changes: 79 additions & 2 deletions app/exceptions/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,83 @@
class InvalidProviderException(Exception):
class BaseForgeException(Exception):
pass

class InvalidProviderException(BaseForgeException):
"""Exception raised when a provider is invalid."""

def __init__(self, identifier: str):
self.identifier = identifier
super().__init__(f"Provider {identifier} is invalid or failed to extract provider info from model_id {identifier}")
super().__init__(f"Provider {identifier} is invalid or failed to extract provider info from model_id {identifier}")


class ProviderAuthenticationException(BaseForgeException):
"""Exception raised when a provider authentication fails."""

def __init__(self, provider_name: str, error: Exception):
self.provider_name = provider_name
self.error = error
super().__init__(f"Provider {provider_name} authentication failed: {error}")


class BaseInvalidProviderSetupException(BaseForgeException):
"""Exception raised when a provider setup is invalid."""

def __init__(self, provider_name: str, error: Exception):
self.provider_name = provider_name
self.error = error
super().__init__(f"Provider {provider_name} setup is invalid: {error}")

class InvalidProviderConfigException(BaseInvalidProviderSetupException):
"""Exception raised when a provider config is invalid."""

def __init__(self, provider_name: str, error: Exception):
super().__init__(provider_name, error)

class InvalidProviderAPIKeyException(BaseInvalidProviderSetupException):
"""Exception raised when a provider API key is invalid."""

def __init__(self, provider_name: str, error: Exception):
super().__init__(provider_name, error)

class ProviderAPIException(BaseForgeException):
"""Exception raised when a provider API error occurs."""

def __init__(self, provider_name: str, error_code: int, error_message: str):
super().__init__(f"Provider {provider_name} API error: {error_code} {error_message}")


class BaseInvalidRequestException(BaseForgeException):
"""Exception raised when a request is invalid."""

def __init__(self, provider_name: str, error: Exception):
self.provider_name = provider_name
self.error = error
super().__init__(f"Provider {provider_name} request is invalid: {error}")

class InvalidCompletionRequestException(BaseInvalidRequestException):
"""Exception raised when a completion request is invalid."""

def __init__(self, provider_name: str, error: Exception):
self.provider_name = provider_name
self.error = error
super().__init__(f"Provider {provider_name} completion request is invalid: {error}")

class InvalidEmbeddingsRequestException(BaseInvalidRequestException):
"""Exception raised when a embeddings request is invalid."""

def __init__(self, provider_name: str, error: Exception):
self.provider_name = provider_name
self.error = error
super().__init__(f"Provider {provider_name} embeddings request is invalid: {error}")

class BaseInvalidForgeKeyException(BaseForgeException):
"""Exception raised when a Forge key is invalid."""

def __init__(self, error: Exception):
self.error = error
super().__init__(f"Forge key is invalid: {error}")


class InvalidForgeKeyException(BaseInvalidForgeKeyException):
"""Exception raised when a Forge key is invalid."""
def __init__(self, error: Exception):
super().__init__(error)
63 changes: 62 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Callable

from dotenv import load_dotenv
from fastapi import APIRouter, FastAPI, Request
from fastapi import APIRouter, FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware

Expand All @@ -20,6 +20,8 @@
from app.core.database import engine
from app.core.logger import get_logger
from app.models.base import Base
from app.exceptions.exceptions import ProviderAuthenticationException, InvalidProviderException, BaseInvalidProviderSetupException, \
ProviderAPIException, BaseInvalidRequestException, BaseInvalidForgeKeyException

load_dotenv()

Expand Down Expand Up @@ -77,6 +79,65 @@ async def dispatch(self, request: Request, call_next: Callable):
openapi_url="/openapi.json" if not is_production else None,
)

### Exception handlers block ###

# Add exception handler for ProviderAuthenticationException
@app.exception_handler(ProviderAuthenticationException)
async def provider_authentication_exception_handler(request: Request, exc: ProviderAuthenticationException):
return HTTPException(
status_code=401,
detail=f"Authentication failed for provider {exc.provider_name}"
)

# Add exception handler for InvalidProviderException
@app.exception_handler(InvalidProviderException)
async def invalid_provider_exception_handler(request: Request, exc: InvalidProviderException):
return HTTPException(
status_code=400,
detail=f"{str(exc)}. Please verify your provider and model details by calling the /models endpoint or visiting https://tensorblock.co/api-docs/model-ids, and ensure you’re using a valid provider name, model name, and model ID."
)

# Add exception handler for BaseInvalidProviderSetupException
@app.exception_handler(BaseInvalidProviderSetupException)
async def base_invalid_provider_setup_exception_handler(request: Request, exc: BaseInvalidProviderSetupException):
return HTTPException(
status_code=400,
detail=str(exc)
)

# Add exception handler for ProviderAPIException
@app.exception_handler(ProviderAPIException)
async def provider_api_exception_handler(request: Request, exc: ProviderAPIException):
return HTTPException(
status_code=exc.error_code,
detail=f"Provider API error: {exc.provider_name} {exc.error_code} {exc.error_message}"
)

# Add exception handler for BaseInvalidRequestException
@app.exception_handler(BaseInvalidRequestException)
async def base_invalid_request_exception_handler(request: Request, exc: BaseInvalidRequestException):
return HTTPException(
status_code=400,
detail=str(exc)
)

# Add exception handler for BaseInvalidForgeKeyException
@app.exception_handler(BaseInvalidForgeKeyException)
async def base_invalid_forge_key_exception_handler(request: Request, exc: BaseInvalidForgeKeyException):
return HTTPException(
status_code=401,
detail=f"Invalid Forge key: {exc.error}"
)

# Add exception handler for NotImplementedError
@app.exception_handler(NotImplementedError)
async def not_implemented_error_handler(request: Request, exc: NotImplementedError):
return HTTPException(
status_code=404,
detail=f"Not implemented: {exc}"
)
### Exception handlers block ends ###


# Middleware to log slow requests
@app.middleware("http")
Expand Down
62 changes: 33 additions & 29 deletions app/services/provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from app.core.logger import get_logger
from app.core.security import decrypt_api_key, encrypt_api_key
from app.exceptions.exceptions import InvalidProviderException
from app.exceptions.exceptions import InvalidProviderException, BaseInvalidRequestException, InvalidForgeKeyException
from app.models.user import User
from app.services.usage_stats_service import UsageStatsService

Expand Down Expand Up @@ -322,6 +322,7 @@ def _get_provider_info_with_prefix(
)

if not matching_provider:
logger.error(f"No matching provider found for {original_model}")
raise InvalidProviderException(original_model)

provider_data = self.provider_keys[matching_provider]
Expand Down Expand Up @@ -358,16 +359,17 @@ def _find_provider_for_unprefixed_model(
provider_data.get("base_url"),
)

logger.error(f"No matching provider found for {model}")
raise InvalidProviderException(model)

def _get_provider_info(self, model: str) -> tuple[str, str, str | None]:
"""
Determine the provider based on the model name.
"""
if not self._keys_loaded:
raise RuntimeError(
"Provider keys must be loaded before calling _get_provider_info. Call _load_provider_keys_async() first."
)
error_message = "Provider keys must be loaded before calling _get_provider_info. Call _load_provider_keys_async() first."
logger.error(error_message)
raise RuntimeError(error_message)

provider_name, model_name_no_prefix = self._extract_provider_name_prefix(model)

Expand Down Expand Up @@ -485,25 +487,23 @@ async def process_request(

model = payload.get("model")
if not model:
raise ValueError("Model is required")

try:
provider_name, actual_model, base_url = self._get_provider_info(model)

# Enforce scope restriction (case-insensitive).
if allowed_provider_names is not None and (
provider_name.lower() not in {p.lower() for p in allowed_provider_names}
):
raise ValueError(
f"API key is not permitted to use provider '{provider_name}'."
)
except ValueError as e:
# Use parameterized logging to avoid issues if the error message contains braces
logger.error("Error getting provider info for model {}: {}", model, str(e))
raise ValueError(
f"Invalid model ID: {model}. Please check your model configuration."
error_message = "Model is required"
logger.error(error_message)
raise BaseInvalidRequestException(
provider_name="unknown",
error=ValueError(error_message)
)

provider_name, actual_model, base_url = self._get_provider_info(model)

# Enforce scope restriction (case-insensitive).
if allowed_provider_names is not None and (
provider_name.lower() not in {p.lower() for p in allowed_provider_names}
):
error_message = f"API key is not permitted to use provider '{provider_name}'."
logger.error(error_message)
raise InvalidForgeKeyException(error=ValueError(error_message))

logger.debug(
f"Processing request for provider: {provider_name}, model: {actual_model}, endpoint: {endpoint}"
)
Expand All @@ -513,7 +513,9 @@ async def process_request(

# Get the provider's API key
if provider_name not in self.provider_keys:
raise ValueError(f"No API key found for provider {provider_name}")
error_message = f"API key is not permitted to use provider '{provider_name}'."
logger.error(error_message)
raise InvalidForgeKeyException(error=ValueError(error_message))

serialized_api_key_config = self.provider_keys[provider_name]["api_key"]
provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls(provider_name)
Expand All @@ -534,9 +536,9 @@ async def process_request(
elif "images/generations" in endpoint:
# TODO: we only support openai for now
if provider_name != "openai":
raise ValueError(
f"Unsupported endpoint: {endpoint} for provider {provider_name}"
)
error_message = f"Unsupported endpoint: {endpoint} for provider {provider_name}"
logger.error(error_message)
raise NotImplementedError(error_message)
result = await adapter.process_image_generation(
endpoint,
payload,
Expand All @@ -545,9 +547,9 @@ async def process_request(
elif "images/edits" in endpoint:
# TODO: we only support openai for now
if provider_name != "openai":
raise NotImplementedError(
f"Unsupported endpoint: {endpoint} for provider {provider_name}"
)
error_message = f"Unsupported endpoint: {endpoint} for provider {provider_name}"
logger.error(error_message)
raise NotImplementedError(error_message)
result = await adapter.process_image_edits(
endpoint,
payload,
Expand All @@ -560,7 +562,9 @@ async def process_request(
api_key,
)
else:
raise ValueError(f"Unsupported endpoint: {endpoint}")
error_message = f"Unsupported endpoint: {endpoint}"
logger.error(error_message)
raise NotImplementedError(error_message)

# Track usage statistics if it's not a streaming response
if not inspect.isasyncgen(result):
Expand Down
Loading
Loading