Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@ RUN uv sync --no-dev
#################
FROM python:3.13-alpine

# Create non-root user
RUN addgroup -S appgroup && adduser -S appuser -G appgroup

# Copy the app from build stage
COPY --from=build /app /app
COPY --from=build --chown=appuser:appgroup /app /app

WORKDIR /app

USER appuser

EXPOSE 8080

CMD ["/app/.venv/bin/uvicorn", "modai.main:app", "--host", "0.0.0.0", "--port", "8080"]
3 changes: 3 additions & 0 deletions backend/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ modules:
openai: chat_openai
module_dependencies:
chat_openai: chat_openai
session: "session"
chat_openai:
class: modai.modules.chat.openai_llm_chat.OpenAILLMChatModule
module_dependencies:
Expand All @@ -21,10 +22,12 @@ modules:
class: modai.modules.model_provider.openai_provider.OpenAIProviderModule
module_dependencies:
llm_provider_store: "model_provider_store"
session: "session"
central_model_provider_router:
class: modai.modules.model_provider.central_router.CentralModelProviderRouter
module_dependencies:
openai_provider: "openai_model_provider"
session: "session"
user_store:
class: modai.modules.user_store.sql_model_user_store.SQLAlchemyUserStore
config:
Expand Down
7 changes: 7 additions & 0 deletions backend/src/modai/modules/chat/web_chat_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, cast
from modai.module import ModuleDependencies
from .module import ChatLLMModule, ChatWebModule as ChatWebModuleBase
from modai.modules.session.module import SessionModule
import openai
from openai.types.responses import ResponseCreateParams

Expand All @@ -15,6 +16,10 @@ class ChatWebModule(ChatWebModuleBase):
def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]):
super().__init__(dependencies, config)
# Router is already set up in base class with response_model=None
self.session_module: SessionModule = dependencies.modules.get("session")
if not self.session_module:
raise ValueError("ChatWebModule requires a 'session' module dependency")

clients_config = config.get("clients", {})
self.clients: Dict[str, ChatLLMModule] = {}
for prefix, module_name in clients_config.items():
Expand All @@ -34,6 +39,8 @@ async def responses_endpoint(
"""
Routes the chat request to the appropriate LLM module based on the model prefix.
"""
self.session_module.validate_session_for_http(request)

model = body_json.get("model", "")

if not model:
Expand Down
24 changes: 19 additions & 5 deletions backend/src/modai/modules/model_provider/central_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from typing import Any, List, Optional
from fastapi import APIRouter, Query
from fastapi import APIRouter, Query, Request
from pydantic import BaseModel

from modai.module import ModaiModule, ModuleDependencies
Expand All @@ -13,6 +13,7 @@
ModelProviderModule,
Model,
)
from modai.modules.session.module import SessionModule


class ModelsListResponse(BaseModel):
Expand Down Expand Up @@ -40,6 +41,12 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]):
super().__init__(dependencies, config)
self.router = APIRouter()

self.session_module: SessionModule = dependencies.modules.get("session")
if not self.session_module:
raise ValueError(
"CentralModelProviderRouter requires a 'session' module dependency"
)

# Add the central route for getting all providers
self.router.add_api_route(
"/api/v1/models/providers",
Expand All @@ -56,6 +63,7 @@ def __init__(self, dependencies: ModuleDependencies, config: dict[str, Any]):

async def get_all_providers(
self,
request: Request,
limit: Optional[int] = Query(
None, ge=1, le=1000, description="Maximum number of providers to return"
),
Expand All @@ -73,6 +81,8 @@ async def get_all_providers(
Returns:
ModelProvidersAllResponse with providers list and pagination info
"""
self.session_module.validate_session_for_http(request)

all_providers = []

# Get all provider modules from dependencies
Expand All @@ -87,7 +97,7 @@ async def get_all_providers(
# Call the get_providers method on each provider module
# But we need to modify it to not apply pagination per module
providers_response = await provider_module.get_providers(
limit=None, offset=None
request, limit=None, offset=None
)
all_providers.extend(providers_response.providers)
except Exception as e:
Expand All @@ -111,14 +121,16 @@ async def get_all_providers(
offset=offset,
)

async def get_all_models(self) -> ModelsListResponse:
async def get_all_models(self, request: Request) -> ModelsListResponse:
"""
Get all models from all providers across all provider types.
Returns in OpenAI-compatible format.

Returns:
ModelsListResponse with all available models
"""
self.session_module.validate_session_for_http(request)

all_models = []

# Get all provider modules from dependencies
Expand All @@ -131,13 +143,15 @@ async def get_all_models(self) -> ModelsListResponse:
try:
# Get all providers for this module type
providers_response = await provider_module.get_providers(
limit=None, offset=None
request, limit=None, offset=None
)

# For each provider, get its models
for provider in providers_response.providers:
try:
models_response = await provider_module.get_models(provider.id)
models_response = await provider_module.get_models(
request, provider.id
)

# Add models with prefixed IDs to avoid conflicts
for model_data in models_response.data:
Expand Down
34 changes: 26 additions & 8 deletions backend/src/modai/modules/model_provider/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from abc import ABC, abstractmethod
from typing import Any, List, Optional
from fastapi import APIRouter, Query
from fastapi import APIRouter, Query, Request
from pydantic import BaseModel

from modai.module import ModaiModule, ModuleDependencies
Expand Down Expand Up @@ -116,6 +116,7 @@ def __init__(
@abstractmethod
async def get_providers(
self,
request: Request,
limit: Optional[int] = Query(
None, ge=1, le=1000, description="Maximum number of providers to return"
),
Expand All @@ -127,98 +128,115 @@ async def get_providers(
Get all model providers with optional pagination.

Args:
request: FastAPI request object
limit: Maximum number of providers to return
offset: Number of providers to skip

Returns:
ModelProvidersListResponse: List of providers with pagination info

Raises:
HTTPException: 401 if not authenticated
HTTPException: 500 if retrieval fails
"""
pass

@abstractmethod
async def get_provider(self, provider_id: str) -> ModelProviderResponse:
async def get_provider(
self, request: Request, provider_id: str
) -> ModelProviderResponse:
"""
Get a specific model provider by ID.

Args:
request: FastAPI request object
provider_id: Unique identifier for the provider

Returns:
ModelProviderResponse: Provider data

Raises:
HTTPException: 401 if not authenticated
HTTPException: 404 if provider not found, 500 if retrieval fails
"""
pass

@abstractmethod
async def create_provider(
self, request: ModelProviderCreateRequest
self, request: Request, provider_data: ModelProviderCreateRequest
) -> ModelProviderResponse:
"""
Create a new model provider.

Args:
request: Provider data
request: FastAPI request object
provider_data: Provider data

Returns:
ModelProviderResponse: Created provider data

Raises:
HTTPException: 401 if not authenticated
HTTPException: 400 for validation errors, 409 for conflicts, 500 for other failures
"""
pass

@abstractmethod
async def update_provider(
self, provider_id: str, request: ModelProviderCreateRequest
self,
request: Request,
provider_id: str,
provider_data: ModelProviderCreateRequest,
) -> ModelProviderResponse:
"""
Update an existing model provider.

Args:
request: FastAPI request object
provider_id: The ID of the provider to update
request: Provider data
provider_data: Provider data

Returns:
ModelProviderResponse: Updated provider data

Raises:
HTTPException: 401 if not authenticated
HTTPException: 400 for validation errors, 404 if provider not found, 409 for conflicts, 500 for other failures
"""
pass

@abstractmethod
async def get_models(self, provider_id: str) -> ModelResponse:
async def get_models(self, request: Request, provider_id: str) -> ModelResponse:
"""
Get available models from a specific provider.

Args:
request: FastAPI request object
provider_id: Unique identifier for the provider

Returns:
ModelResponse: Models data from the provider

Raises:
HTTPException: 401 if not authenticated
HTTPException: 404 if provider not found, 500 if retrieval fails
"""
pass

@abstractmethod
async def delete_provider(self, provider_id: str) -> None:
async def delete_provider(self, request: Request, provider_id: str) -> None:
"""
Delete a model provider.

Args:
request: FastAPI request object
provider_id: ID of the provider to delete

Returns:
None (204 No Content)

Raises:
HTTPException: 401 if not authenticated
HTTPException: 500 if deletion fails
"""
pass
Loading