diff --git a/alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py b/alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py new file mode 100644 index 0000000..4428b60 --- /dev/null +++ b/alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py @@ -0,0 +1,42 @@ +"""update model_mapping type for ProviderKey table + +Revision ID: 9daf34d338f7 +Revises: 08cc005a4bc8 +Create Date: 2025-07-18 21:32:48.791253 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision = "9daf34d338f7" +down_revision = "08cc005a4bc8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Change model_mapping column from String to JSON + op.alter_column( + "provider_keys", + "model_mapping", + existing_type=sa.String(), + type_=postgresql.JSON(astext_type=sa.Text()), + existing_nullable=True, + postgresql_using="model_mapping::json", + ) + + +def downgrade() -> None: + # Change model_mapping column from JSON back to String + op.alter_column( + "provider_keys", + "model_mapping", + existing_type=postgresql.JSON(astext_type=sa.Text()), + type_=sa.String(), + existing_nullable=True, + postgresql_using="model_mapping::text", + ) diff --git a/app/api/dependencies.py b/app/api/dependencies.py index dcd6084..f8333ae 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -12,8 +12,10 @@ from fastapi import Depends, HTTPException, Request, status from fastapi.security import APIKeyHeader, OAuth2PasswordBearer from jose import JWTError, jwt +from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session, joinedload +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session, joinedload, selectinload from app.api.schemas.user import TokenData from app.core.async_cache import ( @@ -23,7 +25,7 @@ forge_scope_cache_async, get_forge_scope_cache_async, ) -from app.core.database 
import get_db +from app.core.database import get_db, get_async_db from app.core.logger import get_logger from app.core.security import ( ALGORITHM, @@ -91,7 +93,7 @@ async def fetch_and_cache_jwks() -> list | None: async def get_current_user( - db: Session = Depends(get_db), token: str = Depends(oauth2_scheme) + db: AsyncSession = Depends(get_async_db), token: str = Depends(oauth2_scheme) ): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -106,7 +108,9 @@ async def get_current_user( token_data = TokenData(username=username) except JWTError as err: raise credentials_exception from err - user = db.query(User).filter(User.username == token_data.username).first() + + result = await db.execute(select(User).filter(User.username == token_data.username)) + user = result.scalar_one_or_none() if user is None: raise credentials_exception return user @@ -143,7 +147,7 @@ async def get_api_key_from_headers(request: Request) -> str: async def get_user_by_api_key( request: Request = None, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> User: """Get user by API key from headers, with caching""" api_key_from_header = await get_api_key_from_headers(request) @@ -183,12 +187,12 @@ async def get_user_by_api_key( # avoids an extra query later in /models. 
cached_scope = await get_forge_scope_cache_async(api_key) - api_key_record = ( - db.query(ForgeApiKey) - .options(joinedload(ForgeApiKey.allowed_provider_keys)) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.key == api_key_from_header, ForgeApiKey.is_active) - .first() ) + api_key_record = result.scalar_one_or_none() if not api_key_record: raise HTTPException( @@ -197,12 +201,12 @@ async def get_user_by_api_key( ) # Get the user associated with this API key and EAGER LOAD all provider keys - user = ( - db.query(User) - .options(joinedload(User.provider_keys)) + result = await db.execute( + select(User) + .options(selectinload(User.provider_keys)) .filter(User.id == api_key_record.user_id) - .first() ) + user = result.scalar_one_or_none() if not user: raise HTTPException( @@ -230,7 +234,7 @@ async def get_user_by_api_key( # Update last used timestamp for the API key api_key_record.last_used_at = datetime.utcnow() - db.commit() + await db.commit() # Cache the user data for future requests await cache_user_async(api_key, user) @@ -338,7 +342,7 @@ async def validate_clerk_jwt(token: str = Depends(clerk_token_header)): async def get_current_user_from_clerk( - db: Session = Depends(get_db), token_payload: dict = Depends(validate_clerk_jwt) + db: AsyncSession = Depends(get_async_db), token_payload: dict = Depends(validate_clerk_jwt) ): """Get the current user from Clerk token, creating if needed""" from urllib.parse import quote @@ -352,7 +356,8 @@ async def get_current_user_from_clerk( ) # Find user by clerk_user_id - user = db.query(User).filter(User.clerk_user_id == clerk_user_id).first() + result = await db.execute(select(User).filter(User.clerk_user_id == clerk_user_id)) + user = result.scalar_one_or_none() # User doesn't exist yet, create one if not user: @@ -398,7 +403,8 @@ async def get_current_user_from_clerk( username = email # Check if username exists and make unique if needed - 
existing_user = db.query(User).filter(User.username == username).first() + result = await db.execute(select(User).filter(User.username == username)) + existing_user = result.scalar_one_or_none() if existing_user: import random @@ -412,20 +418,22 @@ async def get_current_user_from_clerk( username = clerk_user_id # Check if user exists with this email - existing_user = db.query(User).filter(User.email == email).first() + result = await db.execute(select(User).filter(User.email == email)) + existing_user = result.scalar_one_or_none() if existing_user: # Link existing user to Clerk ID try: existing_user.clerk_user_id = clerk_user_id - db.commit() + await db.commit() return existing_user except IntegrityError: # Another request might have already linked this user or created a new one - db.rollback() + await db.rollback() # Retry the query to get the user - user = ( - db.query(User).filter(User.clerk_user_id == clerk_user_id).first() + result = await db.execute( + select(User).filter(User.clerk_user_id == clerk_user_id) ) + user = result.scalar_one_or_none() if user: return user # If still no user, continue with creation attempt @@ -439,17 +447,15 @@ async def get_current_user_from_clerk( username=username, clerk_user_id=clerk_user_id, is_active=True, - hashed_password=get_password_hash( - "CLERK_AUTH_USER" - ), # Add placeholder password for Clerk users + hashed_password="", # Clerk handles authentication ) db.add(user) - db.commit() - db.refresh(user) + await db.commit() + await db.refresh(user) # Create default TensorBlock provider for the new user try: - create_default_tensorblock_provider_for_user(user.id, db) + await create_default_tensorblock_provider_for_user(user.id, db) except Exception as e: # Log error but don't fail user creation logger.warning( @@ -459,12 +465,13 @@ async def get_current_user_from_clerk( return user except IntegrityError as e: # Handle race condition: another request might have created the user - db.rollback() + await db.rollback() if 
"users_clerk_user_id_key" in str(e) or "clerk_user_id" in str(e): # Retry the query to get the user that was created by another request - user = ( - db.query(User).filter(User.clerk_user_id == clerk_user_id).first() + result = await db.execute( + select(User).filter(User.clerk_user_id == clerk_user_id) ) + user = result.scalar_one_or_none() if user: return user else: diff --git a/app/api/routes/api_auth.py b/app/api/routes/api_auth.py index 6b7c232..4a68748 100644 --- a/app/api/routes/api_auth.py +++ b/app/api/routes/api_auth.py @@ -6,21 +6,21 @@ import requests from fastapi import APIRouter, Depends, HTTPException, Request, status from jose import jwt -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies import ( get_current_active_user, get_current_active_user_from_clerk, ) from app.api.schemas.user import User -from app.core.database import get_db +from app.core.database import get_async_db from app.models.user import User as UserModel router = APIRouter() async def get_user_from_any_auth( - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), jwt_user: UserModel | None = Depends(get_current_active_user), clerk_user: UserModel | None = Depends(get_current_active_user_from_clerk), ) -> UserModel: @@ -45,7 +45,7 @@ async def get_user_from_any_auth( @router.get("/me", response_model=User) -def get_unified_current_user( +async def get_unified_current_user( current_user: UserModel = Depends(get_user_from_any_auth), ) -> Any: """ diff --git a/app/api/routes/api_keys.py b/app/api/routes/api_keys.py index 723a34e..27f1642 100644 --- a/app/api/routes/api_keys.py +++ b/app/api/routes/api_keys.py @@ -1,7 +1,9 @@ from typing import Any from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from app.api.dependencies import ( 
get_current_active_user, @@ -14,7 +16,7 @@ ForgeApiKeyUpdate, ) from app.core.async_cache import invalidate_forge_scope_cache_async, invalidate_user_cache_async, invalidate_provider_service_cache_async -from app.core.database import get_db +from app.core.database import get_async_db from app.core.security import generate_forge_api_key from app.models.forge_api_key import ForgeApiKey from app.models.provider_key import ProviderKey as ProviderKeyModel @@ -26,15 +28,17 @@ async def _get_api_keys_internal( - db: Session, current_user: UserModel + db: AsyncSession, current_user: UserModel ) -> list[ForgeApiKeyMasked]: """ Internal logic to get all API keys for the current user. """ - api_keys_query = db.query(ForgeApiKey).filter( - ForgeApiKey.user_id == current_user.id + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) + .filter(ForgeApiKey.user_id == current_user.id) ) - api_keys = api_keys_query.all() + api_keys = result.scalars().all() masked_keys = [] for api_key_db in api_keys: @@ -48,7 +52,7 @@ async def _get_api_keys_internal( async def _create_api_key_internal( - api_key_create: ForgeApiKeyCreate, db: Session, current_user: UserModel + api_key_create: ForgeApiKeyCreate, db: AsyncSession, current_user: UserModel ) -> ForgeApiKeyResponse: """ Internal logic to create a new API key for the current user. 
@@ -63,14 +67,13 @@ async def _create_api_key_internal( if api_key_create.allowed_provider_key_ids is not None: allowed_providers = [] if api_key_create.allowed_provider_key_ids: - allowed_providers = ( - db.query(ProviderKeyModel) - .filter( + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.id.in_(api_key_create.allowed_provider_key_ids), ProviderKeyModel.user_id == current_user.id, ) - .all() ) + allowed_providers = result.scalars().all() if len(allowed_providers) != len( set(api_key_create.allowed_provider_key_ids) ): @@ -81,8 +84,8 @@ async def _create_api_key_internal( db_api_key.allowed_provider_keys = allowed_providers db.add(db_api_key) - db.commit() - db.refresh(db_api_key) + await db.commit() + await db.refresh(db_api_key) response_data = db_api_key.__dict__.copy() response_data["allowed_provider_key_ids"] = [ @@ -92,16 +95,18 @@ async def _create_api_key_internal( async def _update_api_key_internal( - key_id: int, api_key_update: ForgeApiKeyUpdate, db: Session, current_user: UserModel + key_id: int, api_key_update: ForgeApiKeyUpdate, db: AsyncSession, current_user: UserModel ) -> ForgeApiKeyResponse: """ Internal logic to update an API key for the current user. 
""" - db_api_key = ( - db.query(ForgeApiKey) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id) - .first() ) + db_api_key = result.scalar_one_or_none() + if not db_api_key: raise HTTPException(status_code=404, detail="API key not found") @@ -117,14 +122,13 @@ async def _update_api_key_internal( if api_key_update.allowed_provider_key_ids is not None: db_api_key.allowed_provider_keys.clear() if api_key_update.allowed_provider_key_ids: - allowed_providers = ( - db.query(ProviderKeyModel) - .filter( + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.id.in_(api_key_update.allowed_provider_key_ids), ProviderKeyModel.user_id == current_user.id, ) - .all() ) + allowed_providers = result.scalars().all() if len(allowed_providers) != len( set(api_key_update.allowed_provider_key_ids) ): @@ -134,8 +138,8 @@ async def _update_api_key_internal( ) db_api_key.allowed_provider_keys.extend(allowed_providers) - db.commit() - db.refresh(db_api_key) + await db.commit() + await db.refresh(db_api_key) # Invalidate forge scope cache if the scope was updated if api_key_update.allowed_provider_key_ids is not None: @@ -149,16 +153,18 @@ async def _update_api_key_internal( async def _delete_api_key_internal( - key_id: int, db: Session, current_user: UserModel + key_id: int, db: AsyncSession, current_user: UserModel ) -> ForgeApiKeyResponse: """ Internal logic to delete an API key for the current user. 
""" - db_api_key = ( - db.query(ForgeApiKey) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id) - .first() ) + db_api_key = result.scalar_one_or_none() + if not db_api_key: raise HTTPException(status_code=404, detail="API key not found") @@ -174,8 +180,8 @@ async def _delete_api_key_internal( "allowed_provider_key_ids": [pk.id for pk in db_api_key.allowed_provider_keys], } - db.delete(db_api_key) - db.commit() + await db.delete(db_api_key) + await db.commit() await invalidate_user_cache_async(key_to_invalidate) await invalidate_forge_scope_cache_async(key_to_invalidate) @@ -184,16 +190,18 @@ async def _delete_api_key_internal( async def _regenerate_api_key_internal( - key_id: int, db: Session, current_user: UserModel + key_id: int, db: AsyncSession, current_user: UserModel ) -> ForgeApiKeyResponse: """ Internal logic to regenerate an API key for the current user while preserving other settings. 
""" - db_api_key = ( - db.query(ForgeApiKey) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id) - .first() ) + db_api_key = result.scalar_one_or_none() + if not db_api_key: raise HTTPException(status_code=404, detail="API key not found") @@ -207,8 +215,8 @@ async def _regenerate_api_key_internal( new_key_value = generate_forge_api_key() db_api_key.key = new_key_value - db.commit() - db.refresh(db_api_key) + await db.commit() + await db.refresh(db_api_key) response_data = db_api_key.__dict__.copy() response_data["allowed_provider_key_ids"] = [ @@ -222,7 +230,7 @@ async def _regenerate_api_key_internal( @router.get("/", response_model=list[ForgeApiKeyMasked]) async def get_api_keys( - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _get_api_keys_internal(db, current_user) @@ -231,7 +239,7 @@ async def get_api_keys( @router.post("/", response_model=ForgeApiKeyResponse) async def create_api_key( api_key_create: ForgeApiKeyCreate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _create_api_key_internal(api_key_create, db, current_user) @@ -241,7 +249,7 @@ async def create_api_key( async def update_api_key( key_id: int, api_key_update: ForgeApiKeyUpdate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _update_api_key_internal(key_id, api_key_update, db, current_user) @@ -250,7 +258,7 @@ async def update_api_key( @router.delete("/{key_id}", response_model=ForgeApiKeyResponse) async def delete_api_key( key_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel 
= Depends(get_current_active_user), ) -> Any: return await _delete_api_key_internal(key_id, db, current_user) @@ -259,7 +267,7 @@ async def delete_api_key( @router.post("/{key_id}/regenerate", response_model=ForgeApiKeyResponse) async def regenerate_api_key( key_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: return await _regenerate_api_key_internal(key_id, db, current_user) @@ -268,7 +276,7 @@ async def regenerate_api_key( # Clerk versions of the routes @router.get("/clerk", response_model=list[ForgeApiKeyMasked]) async def get_api_keys_clerk( - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _get_api_keys_internal(db, current_user) @@ -277,7 +285,7 @@ async def get_api_keys_clerk( @router.post("/clerk", response_model=ForgeApiKeyResponse) async def create_api_key_clerk( api_key_create: ForgeApiKeyCreate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _create_api_key_internal(api_key_create, db, current_user) @@ -287,7 +295,7 @@ async def create_api_key_clerk( async def update_api_key_clerk( key_id: int, api_key_update: ForgeApiKeyUpdate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _update_api_key_internal(key_id, api_key_update, db, current_user) @@ -296,7 +304,7 @@ async def update_api_key_clerk( @router.delete("/clerk/{key_id}", response_model=ForgeApiKeyResponse) async def delete_api_key_clerk( key_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await 
_delete_api_key_internal(key_id, db, current_user) @@ -305,7 +313,7 @@ async def delete_api_key_clerk( @router.post("/clerk/{key_id}/regenerate", response_model=ForgeApiKeyResponse) async def regenerate_api_key_clerk( key_id: int, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: return await _regenerate_api_key_internal(key_id, db, current_user) diff --git a/app/api/routes/auth.py b/app/api/routes/auth.py index df80cd5..c81ad21 100644 --- a/app/api/routes/auth.py +++ b/app/api/routes/auth.py @@ -3,11 +3,12 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordRequestForm -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from app.api.routes.users import create_user as create_user_endpoint_logic from app.api.schemas.user import Token, User, UserCreate -from app.core.database import get_db +from app.core.database import get_async_db from app.core.logger import get_logger from app.core.security import ( ACCESS_TOKEN_EXPIRE_MINUTES, @@ -22,9 +23,9 @@ @router.post("/register", response_model=User) -def register( +async def register( user_in: UserCreate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Register new user. This will create the user but will not automatically create a Forge API key. @@ -33,7 +34,7 @@ def register( # Call the user creation logic from users.py # This handles checks for existing email/username and password hashing. 
try: - db_user = create_user_endpoint_logic(user_in=user_in, db=db) + db_user = await create_user_endpoint_logic(user_in=user_in, db=db) except HTTPException as e: # Propagate HTTPExceptions (like 400 for existing user) raise e except Exception as e: # Catch any other unexpected errors during user creation @@ -73,13 +74,18 @@ def register( @router.post("/token", response_model=Token) -def login_for_access_token( - db: Session = Depends(get_db), form_data: OAuth2PasswordRequestForm = Depends() +async def login_for_access_token( + db: AsyncSession = Depends(get_async_db), + form_data: OAuth2PasswordRequestForm = Depends() ) -> Any: """ Get an access token for future API requests. """ - user = db.query(UserModel).filter(UserModel.username == form_data.username).first() + result = await db.execute( + select(UserModel).filter(UserModel.username == form_data.username) + ) + user = result.scalar_one_or_none() + if not user or not verify_password(form_data.password, user.hashed_password): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/app/api/routes/provider_keys.py b/app/api/routes/provider_keys.py index 1aff50d..523fd1a 100644 --- a/app/api/routes/provider_keys.py +++ b/app/api/routes/provider_keys.py @@ -2,7 +2,8 @@ from typing import Any from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from starlette import status from app.api.dependencies import ( @@ -15,13 +16,14 @@ ProviderKeyUpdate, ProviderKeyUpsertItem, ) -from app.core.cache import invalidate_provider_service_cache -from app.core.database import get_db +from app.core.async_cache import invalidate_provider_service_cache_async +from app.core.database import get_async_db from app.core.logger import get_logger from app.core.security import decrypt_api_key, encrypt_api_key from app.models.provider_key import ProviderKey as ProviderKeyModel -from app.models.user import 
User +from app.models.user import User as UserModel from app.services.providers.adapter_factory import ProviderAdapterFactory +from app.services.providers.base import ProviderAdapter logger = get_logger(name="provider_keys") @@ -29,247 +31,280 @@ # --- Internal Service Functions --- +def _validate_provider_cls_init(provider_name: str, base_url: str, config: dict[str, Any]) -> ProviderAdapter: + provider_cls = ProviderAdapterFactory.get_adapter_cls(provider_name) + try: + provider_cls(provider_name, base_url, config=config) + except Exception as e: + logger.error({ + "message": f"Error initializing provider {provider_name}", + "extra":{ + "error": str(e), + } + }) + raise HTTPException( + status_code=400, + detail=f"Error initializing provider {provider_name}", + ) + return provider_cls + -def _get_provider_keys_internal( - db: Session, current_user: User -) -> list[ProviderKeyModel]: - """Internal. Retrieve all provider keys for the current user.""" - return ( - db.query(ProviderKeyModel) - .filter(ProviderKeyModel.user_id == current_user.id) - .all() +async def _get_provider_keys_internal( + db: AsyncSession, current_user: UserModel +) -> list[ProviderKey]: + """ + Internal logic to get all provider keys for the current user. + """ + result = await db.execute( + select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id) ) + provider_keys = result.scalars().all() + return [ProviderKey.model_validate(pk) for pk in provider_keys] -def _create_provider_key_internal( - provider_key_in: ProviderKeyCreate, db: Session, current_user: User +async def _process_provider_key_create_data( + db: AsyncSession, + provider_key_create: ProviderKeyCreate, + user_id: int, ) -> ProviderKeyModel: - """Internal. 
Create a new provider key.""" - existing_key = ( - db.query(ProviderKeyModel) - .filter( + provider_name = provider_key_create.provider_name + provider_cls = _validate_provider_cls_init(provider_name, provider_key_create.base_url, provider_key_create.config) + serialized_api_key_config = provider_cls.serialize_api_key_config(provider_key_create.api_key, provider_key_create.config) + + encrypted_key = encrypt_api_key(serialized_api_key_config) + db_provider_key = ProviderKeyModel( + user_id=user_id, + provider_name=provider_name, + encrypted_api_key=encrypted_key, + base_url=provider_key_create.base_url, + model_mapping=provider_key_create.model_mapping, + ) + db.add(db_provider_key) + return db_provider_key + + +async def _create_provider_key_internal( + provider_key_create: ProviderKeyCreate, db: AsyncSession, current_user: UserModel +) -> ProviderKey: + """ + Internal logic to create a new provider key for the current user. + """ + # Check if provider already exists for user + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.user_id == current_user.id, - ProviderKeyModel.provider_name == provider_key_in.provider_name, + ProviderKeyModel.provider_name == provider_key_create.provider_name, ) - .first() ) + existing_key = result.scalar_one_or_none() + if existing_key: raise HTTPException( status_code=400, - detail=f"A key for provider {provider_key_in.provider_name} already exists", + detail=f"Provider key for {provider_key_create.provider_name} already exists", ) + + db_provider_key = await _process_provider_key_create_data(db, provider_key_create, current_user.id) + await db.commit() + await db.refresh(db_provider_key) - model_mapping_json = ( - json.dumps(provider_key_in.model_mapping) - if provider_key_in.model_mapping - else None - ) + # Invalidate caches after creating a new provider key + await invalidate_provider_service_cache_async(current_user.id) - provider_name = provider_key_in.provider_name - provider_adapter_cls = 
ProviderAdapterFactory.get_adapter_cls(provider_name) + return ProviderKey.model_validate(db_provider_key) - # try to initialize the provider adapter - try: - provider_adapter_cls( - provider_name, provider_key_in.base_url, config=provider_key_in.config - ) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Error initializing provider {provider_name}: {e}", - ) - serialized_api_key_config = provider_adapter_cls.serialize_api_key_config( - provider_key_in.api_key, provider_key_in.config - ) +async def _process_provider_key_update_data( + db_provider_key: ProviderKeyModel, + provider_key_update: ProviderKeyUpdate, +) -> ProviderKeyModel: + update_data = provider_key_update.model_dump(exclude_unset=True) + provider_cls = ProviderAdapterFactory.get_adapter_cls(db_provider_key.provider_name) + old_api_key, old_config = provider_cls.deserialize_api_key_config(decrypt_api_key(db_provider_key.encrypted_api_key)) - provider_key = ProviderKeyModel( - provider_name=provider_name, - encrypted_api_key=encrypt_api_key(serialized_api_key_config), - user_id=current_user.id, - base_url=provider_key_in.base_url, - model_mapping=model_mapping_json, - ) - db.add(provider_key) - db.commit() - db.refresh(provider_key) - invalidate_provider_service_cache(current_user.id) - return provider_key + if "api_key" in update_data or "config" in update_data: + api_key = update_data.pop("api_key", None) or old_api_key + config = update_data.pop("config", None) or old_config + _validate_provider_cls_init(db_provider_key.provider_name, db_provider_key.base_url, config) + serialized_api_key_config = provider_cls.serialize_api_key_config(api_key, config) + update_data['encrypted_api_key'] = encrypt_api_key(serialized_api_key_config) + + for field, value in update_data.items(): + setattr(db_provider_key, field, value) + + return db_provider_key -def _update_provider_key_internal( +async def _update_provider_key_internal( provider_name: str, - provider_key_in: 
ProviderKeyUpdate, - db: Session, - current_user: User, -) -> ProviderKeyModel: - """Internal. Update a provider key.""" - provider_key = ( - db.query(ProviderKeyModel) - .filter( - ProviderKeyModel.user_id == current_user.id, + provider_key_update: ProviderKeyUpdate, + db: AsyncSession, + current_user: UserModel, +) -> ProviderKey: + """ + Internal logic to update a provider key for the current user. + """ + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.provider_name == provider_name, + ProviderKeyModel.user_id == current_user.id, ) - .first() ) - if not provider_key: - raise HTTPException( - status_code=404, - detail=f"Provider key for {provider_name} not found", - ) + db_provider_key = result.scalar_one_or_none() + + if not db_provider_key: + raise HTTPException(status_code=404, detail="Provider key not found") + + db_provider_key = await _process_provider_key_update_data(db_provider_key, provider_key_update) - # try to initialize the provider adapter if key info is provided - try: - provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls(provider_name) - _, old_config = provider_adapter_cls.deserialize_api_key_config( - decrypt_api_key(provider_key.encrypted_api_key) - ) - if provider_key_in.api_key or provider_key_in.config: - serialized_api_key_config = provider_adapter_cls.serialize_api_key_config( - provider_key_in.api_key, provider_key_in.config - ) - provider_key.encrypted_api_key = encrypt_api_key(serialized_api_key_config) - if provider_key_in.base_url is not None: - provider_key.base_url = provider_key_in.base_url - if provider_key_in.model_mapping is not None: - provider_key.model_mapping = json.dumps(provider_key_in.model_mapping) - - provider_adapter_cls( - provider_name, - provider_key.base_url, - config=provider_key_in.config or old_config, - ) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Error initializing provider {provider_name}: {e}", - ) + await db.commit() + await 
db.refresh(db_provider_key) + + # Invalidate caches after updating a provider key + await invalidate_provider_service_cache_async(current_user.id) - db.commit() - db.refresh(provider_key) - invalidate_provider_service_cache(current_user.id) - return provider_key + return ProviderKey.model_validate(db_provider_key) -def _delete_provider_key_internal( - provider_name: str, db: Session, current_user: User +async def _process_provider_key_delete_data( + db: AsyncSession, + provider_name: str, + user_id: int, ) -> ProviderKeyModel: - """Internal. Delete a provider key.""" - provider_key = ( - db.query(ProviderKeyModel) - .filter( - ProviderKeyModel.user_id == current_user.id, + result = await db.execute( + select(ProviderKeyModel).filter( ProviderKeyModel.provider_name == provider_name, + ProviderKeyModel.user_id == user_id, ) - .first() ) - if not provider_key: - raise HTTPException( - status_code=404, - detail=f"Provider key for {provider_name} not found", - ) - db.delete(provider_key) - db.commit() - invalidate_provider_service_cache(current_user.id) - return provider_key + db_provider_key = result.scalar_one_or_none() + + if not db_provider_key: + raise HTTPException(status_code=404, detail="Provider key not found") + + # Store the provider key data before deletion + provider_key_data = ProviderKey.model_validate(db_provider_key) + + await db.delete(db_provider_key) + + return provider_key_data + + +async def _delete_provider_key_internal( + provider_name: str, db: AsyncSession, current_user: UserModel +) -> ProviderKey: + """ + Internal logic to delete a provider key for the current user. 
+ """ + provider_key_data = await _process_provider_key_delete_data(db, provider_name, current_user.id) + await db.commit() + + # Invalidate caches after deleting a provider key + await invalidate_provider_service_cache_async(current_user.id) + + return provider_key_data + +# --- API Endpoints --- @router.get("/", response_model=list[ProviderKey]) -def get_provider_keys( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), +async def get_provider_keys( + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _get_provider_keys_internal(db, current_user) + return await _get_provider_keys_internal(db, current_user) @router.post("/", response_model=ProviderKey) -def create_provider_key( - provider_key_in: ProviderKeyCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), +async def create_provider_key( + provider_key_create: ProviderKeyCreate, + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _create_provider_key_internal(provider_key_in, db, current_user) + return await _create_provider_key_internal(provider_key_create, db, current_user) @router.put("/{provider_name}", response_model=ProviderKey) -def update_provider_key( +async def update_provider_key( provider_name: str, - provider_key_in: ProviderKeyUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), + provider_key_update: ProviderKeyUpdate, + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _update_provider_key_internal( - provider_name, provider_key_in, db, current_user + return await _update_provider_key_internal( + provider_name, provider_key_update, db, current_user ) @router.delete("/{provider_name}", response_model=ProviderKey) -def delete_provider_key( +async def 
delete_provider_key( provider_name: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: - return _delete_provider_key_internal(provider_name, db, current_user) + return await _delete_provider_key_internal(provider_name, db, current_user) + + +# --- Clerk API Routes --- -# Clerk versions of the routes @router.get("/clerk", response_model=list[ProviderKey]) -def get_provider_keys_clerk( - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user_from_clerk), +async def get_provider_keys_clerk( + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: - return _get_provider_keys_internal(db, current_user) + return await _get_provider_keys_internal(db, current_user) @router.post("/clerk", response_model=ProviderKey) -def create_provider_key_clerk( - provider_key_in: ProviderKeyCreate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user_from_clerk), +async def create_provider_key_clerk( + provider_key_create: ProviderKeyCreate, + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: - return _create_provider_key_internal(provider_key_in, db, current_user) + return await _create_provider_key_internal(provider_key_create, db, current_user) @router.put("/clerk/{provider_name}", response_model=ProviderKey) -def update_provider_key_clerk( +async def update_provider_key_clerk( provider_name: str, - provider_key_in: ProviderKeyUpdate, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user_from_clerk), + provider_key_update: ProviderKeyUpdate, + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: - return 
_update_provider_key_internal(
-        provider_name, provider_key_in, db, current_user
+    return await _update_provider_key_internal(
+        provider_name, provider_key_update, db, current_user
     )


 @router.delete("/clerk/{provider_name}", response_model=ProviderKey)
-def delete_provider_key_clerk(
+async def delete_provider_key_clerk(
     provider_name: str,
-    db: Session = Depends(get_db),
-    current_user: User = Depends(get_current_active_user_from_clerk),
+    db: AsyncSession = Depends(get_async_db),
+    current_user: UserModel = Depends(get_current_active_user_from_clerk),
 ) -> Any:
-    return _delete_provider_key_internal(provider_name, db, current_user)
+    return await _delete_provider_key_internal(provider_name, db, current_user)


 # --- Batch Upsert API Endpoint ---


-def _batch_upsert_provider_keys_internal(
+async def _batch_upsert_provider_keys_internal(
     items: list[ProviderKeyUpsertItem],
-    db: Session,
-    current_user: User,
-) -> list[ProviderKeyModel]:
+    db: AsyncSession,
+    current_user: UserModel,
+) -> list[ProviderKey]:
     """
     Internal logic for batch creating or updating provider keys for the current user.
     """
     processed_keys: list[ProviderKeyModel] = []
+    processed: bool = False

     # 1. Fetch all existing keys for the user
-    existing_keys_query = (
-        db.query(ProviderKeyModel)
-        .filter(ProviderKeyModel.user_id == current_user.id)
-        .all()
+    result = await db.execute(
+        select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id)
     )
+    existing_keys_query = result.scalars().all()

     # 2.
Map them by provider_name for efficient lookup existing_keys_map: dict[str, ProviderKeyModel] = { key.provider_name: key for key in existing_keys_query @@ -278,98 +313,23 @@ def _batch_upsert_provider_keys_internal( for item in items: if "****" in item.api_key: continue + try: + existing_provider_key: ProviderKeyModel | None = existing_keys_map.get(item.provider_name) + # Handle deletion if api_key is "DELETE" if item.api_key == "DELETE": - try: - _delete_provider_key_internal(item.provider_name, db, current_user) - except HTTPException as e: - if ( - e.status_code != status.HTTP_404_NOT_FOUND - ): # Ignore 404 errors for missing keys - raise - continue - - db_key_to_process: ProviderKeyModel | None = existing_keys_map.get( - item.provider_name - ) - - if db_key_to_process: # Update existing key - try: - # try to initialize the provider adapter if key info is provided - provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls( - item.provider_name - ) - _, old_config = provider_adapter_cls.deserialize_api_key_config( - decrypt_api_key(db_key_to_process.encrypted_api_key) - ) - if item.api_key or item.config: - serialized_api_key_config = ( - provider_adapter_cls.serialize_api_key_config( - item.api_key, item.config - ) - ) - db_key_to_process.encrypted_api_key = encrypt_api_key( - serialized_api_key_config - ) - if ( - item.base_url is not None - ): # Allows setting base_url to "" or null - db_key_to_process.base_url = item.base_url - if item.model_mapping is not None: - db_key_to_process.model_mapping = json.dumps(item.model_mapping) - elif ( - hasattr(item, "model_mapping") and item.model_mapping is None - ): # Explicitly clear if None - db_key_to_process.model_mapping = None - provider_adapter_cls( - item.provider_name, - db_key_to_process.base_url, - config=item.config or old_config, - ) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Error updating provider {item.provider_name}: {e}", - ) - # No need to db.add() as it's 
already tracked by the session + if existing_provider_key: + await _process_provider_key_delete_data(db, item.provider_name, current_user.id) + processed = True + elif existing_provider_key: # Update existing key + db_key_to_process = await _process_provider_key_update_data(existing_provider_key, ProviderKeyUpdate.model_validate(item)) + processed_keys.append(db_key_to_process) + processed = True else: # Create new key - if not item.api_key: - raise HTTPException( - status_code=400, - detail=f"api_key is required to create a new provider key for {item.provider_name}", - ) - model_mapping_json = ( - json.dumps(item.model_mapping) if item.model_mapping else None - ) - provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls( - item.provider_name - ) - # try to initialize the provider adapter - try: - provider_adapter_cls( - item.provider_name, item.base_url, config=item.config - ) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Error initializing provider {item.provider_name}: {e}", - ) - serialized_api_key_config = ( - provider_adapter_cls.serialize_api_key_config( - item.api_key, item.config - ) - ) - db_key_to_process = ProviderKeyModel( - provider_name=item.provider_name, - encrypted_api_key=encrypt_api_key(serialized_api_key_config), - user_id=current_user.id, - base_url=item.base_url, - model_mapping=model_mapping_json, - ) - db.add(db_key_to_process) - - processed_keys.append(db_key_to_process) + db_key_to_process = await _process_provider_key_create_data(db, ProviderKeyCreate.model_validate(item), current_user.id) + processed_keys.append(db_key_to_process) + processed = True except HTTPException as http_exc: # db.rollback() # Optional: rollback if any item fails, or handle partial success @@ -390,16 +350,15 @@ def _batch_upsert_provider_keys_internal( detail=f"An unexpected error occurred while processing '{item.provider_name}'.", ) - if processed_keys: + if processed: try: - db.commit() + await db.commit() for key in 
processed_keys: - db.refresh( - key - ) # Refresh each key to get DB-generated values like id, timestamps - invalidate_provider_service_cache(current_user.id) + await db.refresh(key) # Refresh each key to get DB-generated values like id, timestamps + processed_keys = [ProviderKey.model_validate(key) for key in processed_keys] + await invalidate_provider_service_cache_async(current_user.id) except Exception as e: - db.rollback() + await db.rollback() error_message_prefix = "Error during final commit/refresh in batch upsert" if hasattr(current_user, "email"): # Check if it's a full User object error_message_prefix += f" (User: {current_user.email})" @@ -412,24 +371,24 @@ def _batch_upsert_provider_keys_internal( @router.post("/batch-upsert", response_model=list[ProviderKey]) -def batch_upsert_provider_keys( +async def batch_upsert_provider_keys( items: list[ProviderKeyUpsertItem], - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user), ) -> Any: """ Batch create or update provider keys for the current user. """ - return _batch_upsert_provider_keys_internal(items, db, current_user) + return await _batch_upsert_provider_keys_internal(items, db, current_user) @router.post("/clerk/batch-upsert", response_model=list[ProviderKey]) -def batch_upsert_provider_keys_clerk( +async def batch_upsert_provider_keys_clerk( items: list[ProviderKeyUpsertItem], - db: Session = Depends(get_db), - current_user: User = Depends(get_current_active_user_from_clerk), + db: AsyncSession = Depends(get_async_db), + current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: """ Batch create or update provider keys for the current user (Clerk authenticated). 
""" - return _batch_upsert_provider_keys_internal(items, db, current_user) + return await _batch_upsert_provider_keys_internal(items, db, current_user) diff --git a/app/api/routes/proxy.py b/app/api/routes/proxy.py index 669973a..82384b6 100644 --- a/app/api/routes/proxy.py +++ b/app/api/routes/proxy.py @@ -2,7 +2,9 @@ from typing import Any from fastapi import APIRouter, Depends, HTTPException, Request -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from starlette.responses import StreamingResponse from app.api.dependencies import get_user_by_api_key @@ -15,8 +17,9 @@ ImageGenerationRequest, ) from app.core.async_cache import async_provider_service_cache -from app.core.database import get_db +from app.core.database import get_async_db from app.core.logger import get_logger +from app.models.forge_api_key import ForgeApiKey from app.models.user import User from app.services.provider_service import ProviderService @@ -29,7 +32,7 @@ # None → unrestricted, [] → explicitly no providers. 
# ------------------------------------------------------------- async def _get_allowed_provider_names( - request: Request, db: Session + request: Request, db: AsyncSession ) -> list[str] | None: api_key = getattr(request.state, "forge_api_key", None) if api_key is None: @@ -43,19 +46,15 @@ async def _get_allowed_provider_names( if allowed is not None: return allowed - from sqlalchemy.orm import joinedload - - from app.models.forge_api_key import ForgeApiKey - allowed = await async_provider_service_cache.get(f"forge_scope:{api_key}") if allowed is None: - forge_key = ( - db.query(ForgeApiKey) - .options(joinedload(ForgeApiKey.allowed_provider_keys)) + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) .filter(ForgeApiKey.key == f"forge-{api_key}", ForgeApiKey.is_active) - .first() ) + forge_key = result.scalar_one_or_none() if forge_key is None: raise HTTPException( status_code=401, detail="Forge API key not found or inactive" @@ -74,7 +73,7 @@ async def create_chat_completion( request: Request, chat_request: ChatCompletionRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create a chat completion (OpenAI-compatible endpoint). @@ -123,7 +122,7 @@ async def create_completion( request: Request, completion_request: CompletionRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create a completion (OpenAI-compatible endpoint). @@ -166,7 +165,7 @@ async def create_image_generation( request: Request, image_generation_request: ImageGenerationRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create an image generation (OpenAI-compatible endpoint). 
@@ -197,7 +196,7 @@ async def create_image_edits( request: Request, image_edits_request: ImageEditsRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: try: provider_service = await ProviderService.async_get_instance(user, db) @@ -221,7 +220,7 @@ async def create_image_edits( async def list_models( request: Request, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> dict[str, Any]: """ List available models. Only models from providers that are within the scope of the @@ -237,6 +236,7 @@ async def list_models( ) return {"object": "list", "data": models} except Exception as err: + logger.error(f"Error listing models: {str(err)}") raise HTTPException( status_code=500, detail=f"Error listing models: {str(err)}" ) from err @@ -248,7 +248,7 @@ async def create_embeddings( request: Request, embeddings_request: EmbeddingsRequest, user: User = Depends(get_user_by_api_key), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create embeddings (OpenAI-compatible endpoint). diff --git a/app/api/routes/stats.py b/app/api/routes/stats.py index 7aa39be..a470e07 100644 --- a/app/api/routes/stats.py +++ b/app/api/routes/stats.py @@ -2,12 +2,12 @@ from typing import Any from fastapi import APIRouter, Depends, HTTPException, Query -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.api.dependencies import ( get_current_active_user_from_clerk, get_current_user, - get_db, + get_async_db, get_user_by_api_key, ) from app.models.user import User @@ -29,7 +29,7 @@ async def get_user_stats( end_date: date | None = Query( None, description="End date for filtering (YYYY-MM-DD)" ), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ): """ Get aggregated usage statistics for the current user, queried from request logs. 
@@ -38,7 +38,7 @@ async def get_user_stats( """ # Note: Service layer now handles aggregation and filtering # We pass the query parameters directly to the service method - stats = UsageStatsService.get_user_stats( + stats = await UsageStatsService.get_user_stats( db=db, user_id=current_user.id, provider=provider, @@ -62,7 +62,7 @@ async def get_user_stats_clerk( end_date: date | None = Query( None, description="End date for filtering (YYYY-MM-DD)" ), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ): """ Get aggregated usage statistics for the current user, queried from request logs. @@ -71,7 +71,7 @@ async def get_user_stats_clerk( """ # Note: Service layer now handles aggregation and filtering # We pass the query parameters directly to the service method - stats = UsageStatsService.get_user_stats( + stats = await UsageStatsService.get_user_stats( db=db, user_id=current_user.id, provider=provider, @@ -93,7 +93,7 @@ async def get_all_stats( end_date: date | None = Query( None, description="End date for filtering (YYYY-MM-DD)" ), - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), ): """ Get aggregated usage statistics for all users, queried from request logs. 
@@ -106,7 +106,7 @@ async def get_all_stats(
             status_code=403, detail="Not authorized to access admin statistics"
         )

-    stats = UsageStatsService.get_all_stats(
+    stats = await UsageStatsService.get_all_stats(
         db=db, provider=provider, model=model, start_date=start_date, end_date=end_date
     )
     return stats
diff --git a/app/api/routes/users.py b/app/api/routes/users.py
index 9abdce9..52d84ae 100644
--- a/app/api/routes/users.py
+++ b/app/api/routes/users.py
@@ -1,17 +1,15 @@
 from typing import Any

 from fastapi import APIRouter, Depends, HTTPException
-from sqlalchemy.orm import Session
-
-from app.api.dependencies import (
-    get_current_active_user,
-    get_current_active_user_from_clerk,
-)
-from app.api.schemas.user import MaskedUser, User, UserCreate, UserUpdate
-from app.core.cache import invalidate_user_cache
-from app.core.database import get_db
-from app.core.logger import get_logger
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.api.dependencies import get_current_active_user, get_current_active_user_from_clerk
+from app.api.schemas.user import User, UserCreate, UserUpdate, MaskedUser
+from app.core.cache import invalidate_user_cache
+from app.core.database import get_async_db
 from app.core.security import get_password_hash
+from app.core.logger import get_logger
 from app.models.user import User as UserModel
 from app.services.provider_service import create_default_tensorblock_provider_for_user
@@ -20,53 +18,61 @@
 router = APIRouter()


-@router.post("/", response_model=User, status_code=201)
-def create_user(
-    user_in: UserCreate, db: Session = Depends(get_db)
-) -> Any:  # pragma: no cover - (Covered by test_user_creation_and_login)
+@router.post("/", response_model=User, status_code=201)
+async def create_user(
+    user_in: UserCreate, db: AsyncSession = Depends(get_async_db)
+) -> Any:
     """
-    Create new user.
+    Create a new user.
""" - db_user = db.query(UserModel).filter(UserModel.email == user_in.email).first() + # Check if email already exists + result = await db.execute( + select(UserModel).filter(UserModel.email == user_in.email) + ) + db_user = result.scalar_one_or_none() if db_user: raise HTTPException( - status_code=400, - detail="The user with this email already exists in the system.", + status_code=400, detail="Email already registered" ) - db_user = db.query(UserModel).filter(UserModel.username == user_in.username).first() + + # Check if username already exists + result = await db.execute( + select(UserModel).filter(UserModel.username == user_in.username) + ) + db_user = result.scalar_one_or_none() if db_user: raise HTTPException( - status_code=400, - detail="The user with this username already exists in the system.", + status_code=400, detail="Username already registered" ) - + + # Create new user hashed_password = get_password_hash(user_in.password) db_user = UserModel( - username=user_in.username, email=user_in.email, + username=user_in.username, hashed_password=hashed_password, - # Removed automatic API key generation on user creation - # api_key=generate_forge_api_key(), # Users will create keys via /api-keys endpoint - is_active=True, # Default to active, admin can deactivate ) db.add(db_user) - db.commit() - db.refresh(db_user) + await db.commit() + await db.refresh(db_user) # Create default TensorBlock provider for the new user try: - create_default_tensorblock_provider_for_user(db_user.id, db) + await create_default_tensorblock_provider_for_user(db_user.id, db) except Exception as e: # Log error but don't fail user creation - logger.warning( - f"Failed to create default TensorBlock provider for user {db_user.id}: {e}" - ) + logger.error({ + "message": f"Error creating default TensorBlock provider for user {db_user.id}", + "extra": { + "error": str(e), + } + }) return db_user @router.get("/me", response_model=MaskedUser) -def read_user_me( +async def read_user_me( 
current_user: UserModel = Depends(get_current_active_user), ) -> Any: """ @@ -85,7 +90,7 @@ def read_user_me( @router.get("/me/clerk", response_model=MaskedUser) -def read_user_me_clerk( +async def read_user_me_clerk( current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: """ @@ -103,9 +108,9 @@ def read_user_me_clerk( @router.put("/me", response_model=User) -def update_user_me( +async def update_user_me( user_in: UserUpdate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user), ) -> Any: """ @@ -119,8 +124,8 @@ def update_user_me( current_user.hashed_password = get_password_hash(user_in.password) db.add(current_user) - db.commit() - db.refresh(current_user) + await db.commit() + await db.refresh(current_user) invalidate_user_cache( current_user.id ) # Assuming user_id is the cache key for user object @@ -131,15 +136,15 @@ def update_user_me( @router.put("/me/clerk", response_model=User) -def update_user_me_clerk( +async def update_user_me_clerk( user_in: UserUpdate, - db: Session = Depends(get_db), + db: AsyncSession = Depends(get_async_db), current_user: UserModel = Depends(get_current_active_user_from_clerk), ) -> Any: """ Update current user from Clerk. """ - return update_user_me(user_in, db, current_user) + return await update_user_me(user_in, db, current_user) # The regenerate_api_key and regenerate_api_key_clerk endpoints have been removed. 
diff --git a/app/api/routes/webhooks.py b/app/api/routes/webhooks.py index ccfaa8a..2dac064 100644 --- a/app/api/routes/webhooks.py +++ b/app/api/routes/webhooks.py @@ -1,12 +1,14 @@ import json import os +from typing import Any from fastapi import APIRouter, Depends, HTTPException, Request, status +from sqlalchemy import select from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session -from svix.webhooks import Webhook, WebhookVerificationError +from sqlalchemy.ext.asyncio import AsyncSession +from svix import Webhook, WebhookVerificationError -from app.core.database import get_db +from app.core.database import get_async_db from app.core.logger import get_logger from app.core.security import generate_forge_api_key from app.models.user import User @@ -21,7 +23,7 @@ @router.post("/clerk") -async def clerk_webhook_handler(request: Request, db: Session = Depends(get_db)): +async def clerk_webhook_handler(request: Request, db: AsyncSession = Depends(get_async_db)): """ Handle Clerk webhooks for user events. 
@@ -99,100 +101,13 @@ async def clerk_webhook_handler(request: Request, db: Session = Depends(get_db)) # Handle different event types if event_type == "user.created": - # Check if user already exists - user = db.query(User).filter(User.clerk_user_id == clerk_user_id).first() - if user: - return {"status": "success", "message": "User already exists"} - - # Check if user exists with this email - existing_user = db.query(User).filter(User.email == email).first() - if existing_user: - # Link existing user to Clerk ID - try: - existing_user.clerk_user_id = clerk_user_id - db.commit() - return {"status": "success", "message": "Linked to existing user"} - except IntegrityError: - # Another request might have already linked this user or created a new one - db.rollback() - # Retry the query to get the user - user = ( - db.query(User) - .filter(User.clerk_user_id == clerk_user_id) - .first() - ) - if user: - return {"status": "success", "message": "User already exists"} - # If still no user, continue with creation attempt - - # Create new user - forge_api_key = generate_forge_api_key() - - try: - user = User( - email=email, - username=username, - clerk_user_id=clerk_user_id, - is_active=True, - forge_api_key=forge_api_key, - ) - db.add(user) - db.commit() - - # Create default TensorBlock provider for the new user - try: - create_default_tensorblock_provider_for_user(user.id, db) - except Exception as e: - # Log error but don't fail user creation - logger.warning( - f"Failed to create default TensorBlock provider for user {user.id}: {e}" - ) - - return {"status": "success", "message": "User created"} - except IntegrityError as e: - # Handle race condition: another request might have created the user - db.rollback() - if "users_clerk_user_id_key" in str(e) or "clerk_user_id" in str(e): - # Retry the query to get the user that was created by another request - user = ( - db.query(User) - .filter(User.clerk_user_id == clerk_user_id) - .first() - ) - if user: - return {"status": 
"success", "message": "User already exists"} - else: - # This shouldn't happen, but handle it gracefully - return { - "status": "error", - "message": "Failed to create user due to database constraint", - } - else: - # Re-raise other integrity errors - raise + await handle_user_created(event_data, db) elif event_type == "user.updated": - # Update user if they exist - user = db.query(User).filter(User.clerk_user_id == clerk_user_id).first() - if not user: - return {"status": "error", "message": "User not found"} - - # Update fields - if email and user.email != email: - user.email = email - if username and user.username != username: - user.username = username - - db.commit() - return {"status": "success", "message": "User updated"} + await handle_user_updated(event_data, db) elif event_type == "user.deleted": - # Deactivate user rather than delete - user = db.query(User).filter(User.clerk_user_id == clerk_user_id).first() - if user: - user.is_active = False - db.commit() - return {"status": "success", "message": "User deactivated"} + await handle_user_deleted(event_data, db) return {"status": "success", "message": f"Event {event_type} processed"} @@ -205,3 +120,121 @@ async def clerk_webhook_handler(request: Request, db: Session = Depends(get_db)) status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error processing webhook: {str(e)}", ) + + +async def handle_user_created(event_data: dict, db: AsyncSession): + """Handle user.created event from Clerk""" + try: + clerk_user_id = event_data.get("id") + email = event_data.get("email_addresses", [{}])[0].get("email_address", "") + username = ( + event_data.get("username") + or event_data.get("first_name", "") + or email.split("@")[0] + ) + + logger.info(f"Creating user from Clerk webhook: {username} ({email})") + + # Check if user already exists by clerk_user_id + result = await db.execute( + select(User).filter(User.clerk_user_id == clerk_user_id) + ) + user = result.scalar_one_or_none() + if user: + 
logger.info(f"User {username} already exists with Clerk ID")
+            return
+
+        # Check if user exists with this email
+        result = await db.execute(
+            select(User).filter(User.email == email)
+        )
+        existing_user = result.scalar_one_or_none()
+        if existing_user:
+            # Link existing user to Clerk ID
+            existing_user.clerk_user_id = clerk_user_id
+            await db.commit()
+            logger.info(f"Linked existing user {existing_user.username} to Clerk ID")
+            return
+
+        # Create new user
+        user = User(
+            username=username,
+            email=email,
+            clerk_user_id=clerk_user_id,
+            is_active=True,
+            hashed_password="",  # Clerk handles authentication
+        )
+        db.add(user)
+        await db.commit()
+        await db.refresh(user)
+
+        # Create default provider for the user
+        await create_default_tensorblock_provider_for_user(user.id, db)
+
+        logger.info(f"Successfully created user {username} with ID {user.id}")
+
+    except Exception as e:
+        await db.rollback()
+        logger.error(f"Failed to create user from webhook: {e}", exc_info=True)
+        raise
+
+
+async def handle_user_updated(event_data: dict, db: AsyncSession):
+    """Handle user.updated event from Clerk"""
+    try:
+        clerk_user_id = event_data.get("id")
+        email = event_data.get("email_addresses", [{}])[0].get("email_address", "")
+        username = (
+            event_data.get("username")
+            or event_data.get("first_name", "")
+            or email.split("@")[0]
+        )
+
+        logger.info(f"Updating user from Clerk webhook: {username} ({email})")
+
+        result = await db.execute(
+            select(User).filter(User.clerk_user_id == clerk_user_id)
+        )
+        user = result.scalar_one_or_none()
+        if not user:
+            logger.warning(f"User with Clerk ID {clerk_user_id} not found for update")
+            return
+
+        # Update user information
+        user.username = username
+        user.email = email
+        await db.commit()
+
+        logger.info(f"Successfully updated user {username}")
+
+    except Exception as e:
+        await db.rollback()
+        logger.error(f"Failed to update user from webhook: {e}", exc_info=True)
+        raise
+
+
+async def handle_user_deleted(event_data: dict, db: 
AsyncSession): + """Handle user.deleted event from Clerk""" + try: + clerk_user_id = event_data.get("id") + + logger.info(f"Deleting user from Clerk webhook: {clerk_user_id}") + + result = await db.execute( + select(User).filter(User.clerk_user_id == clerk_user_id) + ) + user = result.scalar_one_or_none() + if not user: + logger.warning(f"User with Clerk ID {clerk_user_id} not found for deletion") + return + + # Deactivate user instead of deleting to preserve data integrity + user.is_active = False + await db.commit() + + logger.info(f"Successfully deactivated user {user.username}") + + except Exception as e: + await db.rollback() + logger.error(f"Failed to delete user from webhook: {e}", exc_info=True) + raise diff --git a/app/api/schemas/provider_key.py b/app/api/schemas/provider_key.py index aaaac2f..17dbc57 100644 --- a/app/api/schemas/provider_key.py +++ b/app/api/schemas/provider_key.py @@ -64,6 +64,7 @@ class ProviderKeyCreate(ProviderKeyBase): class ProviderKeyUpdate(BaseModel): api_key: str | None = None + config: dict[str, str] | None = None base_url: str | None = None model_mapping: dict[str, str] | None = None diff --git a/app/core/async_cache.py b/app/core/async_cache.py index a3e34e6..c3ba536 100644 --- a/app/core/async_cache.py +++ b/app/core/async_cache.py @@ -10,7 +10,8 @@ from collections.abc import Callable from typing import Any, TypeVar -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select from app.api.schemas.cached_user import CachedUser from app.core.logger import get_logger @@ -391,7 +392,7 @@ async def invalidate_all_caches_async() -> None: logger.debug("Cache: Invalidated all caches") -async def warm_cache_async(db: Session) -> None: +async def warm_cache_async(db: AsyncSession) -> None: """Pre-cache frequently accessed data asynchronously""" from app.models.user import User from app.services.provider_service import ProviderService @@ -400,14 +401,17 @@ async def 
warm_cache_async(db: Session) -> None: logger.info("Cache: Starting cache warm-up...") # Cache active users - active_users = db.query(User).filter(User.is_active).all() + result = await db.execute(select(User).filter(User.is_active)) + active_users = result.scalars().all() + for user in active_users: # Get user's Forge API keys - forge_api_keys = ( - db.query(ForgeApiKey) + result = await db.execute( + select(ForgeApiKey) .filter(ForgeApiKey.user_id == user.id, ForgeApiKey.is_active) - .all() ) + forge_api_keys = result.scalars().all() + for key in forge_api_keys: await cache_user_async(key.key, user) diff --git a/app/core/cache.py b/app/core/cache.py index 26f5693..7fe9fbe 100644 --- a/app/core/cache.py +++ b/app/core/cache.py @@ -1,3 +1,4 @@ +# TODO: deprecate this and move to async cache import functools import os import time @@ -352,11 +353,6 @@ async def warm_cache(db: Session) -> None: # Cache user with their Forge API key cache_user(key.key, user) - # Cache provider services for active users - for user in active_users: - service = ProviderService.get_instance(user, db) - cache_provider_service(user.id, service) - if DEBUG_CACHE: logger.info(f"Cache: Warm-up complete. 
Cached {len(active_users)} users") diff --git a/app/core/database.py b/app/core/database.py index 92cf992..5253f03 100644 --- a/app/core/database.py +++ b/app/core/database.py @@ -1,31 +1,87 @@ import os from dotenv import load_dotenv +from contextlib import asynccontextmanager from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession from sqlalchemy.orm import declarative_base, sessionmaker load_dotenv() +POOL_SIZE = 5 +MAX_OVERFLOW = 10 +MAX_TIMEOUT = 30 +POOL_RECYCLE = 1800 + SQLALCHEMY_DATABASE_URL = os.getenv("DATABASE_URL") if not SQLALCHEMY_DATABASE_URL: raise ValueError("DATABASE_URL environment variable is not set") +# Sync engine and session engine = create_engine( SQLALCHEMY_DATABASE_URL, - pool_size=5, - max_overflow=10, - pool_timeout=30, - pool_recycle=1800, + pool_size=POOL_SIZE, + max_overflow=MAX_OVERFLOW, + pool_timeout=MAX_TIMEOUT, + pool_recycle=POOL_RECYCLE, + echo=False, ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) -Base = declarative_base() - -# Dependency +# Sync dependency def get_db(): db = SessionLocal() try: yield db finally: db.close() + + +# Async engine and session (new) +# Convert the DATABASE_URL to async format if it's using psycopg2 +ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL +if SQLALCHEMY_DATABASE_URL.startswith("postgresql://"): + ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://") +elif SQLALCHEMY_DATABASE_URL.startswith("postgresql+psycopg2://"): + ASYNC_DATABASE_URL = SQLALCHEMY_DATABASE_URL.replace("postgresql+psycopg2://", "postgresql+asyncpg://") + +async_engine = create_async_engine( + ASYNC_DATABASE_URL, + pool_size=POOL_SIZE, + max_overflow=MAX_OVERFLOW, + pool_timeout=MAX_TIMEOUT, + pool_recycle=POOL_RECYCLE, + echo=False, +) + +AsyncSessionLocal = async_sessionmaker( + bind=async_engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, +) + +Base = 
declarative_base() + + +# Async dependency +async def get_async_db(): + async with AsyncSessionLocal() as session: + try: + yield session + finally: + await session.close() + + +@asynccontextmanager +async def get_db_session(): + """Async context manager for database sessions""" + async with AsyncSessionLocal() as session: + try: + yield session + except Exception: + await session.rollback() + raise + finally: + await session.close() \ No newline at end of file diff --git a/app/models/provider_key.py b/app/models/provider_key.py index 9fb4b2b..eb191e3 100644 --- a/app/models/provider_key.py +++ b/app/models/provider_key.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, ForeignKey, Integer, String +from sqlalchemy import Column, ForeignKey, Integer, String, JSON from sqlalchemy.orm import relationship from app.models.forge_api_key import forge_api_key_provider_scope_association @@ -18,7 +18,7 @@ class ProviderKey(BaseModel): base_url = Column( String, nullable=True ) # Allow custom base URLs for some providers - model_mapping = Column(String, nullable=True) # JSON string for model name mappings + model_mapping = Column(JSON, nullable=True) # JSON dict for model name mappings # Relationship to ForgeApiKeys that have this provider key in their scope scoped_forge_api_keys = relationship( diff --git a/app/services/provider_service.py b/app/services/provider_service.py index 61688e3..281050f 100644 --- a/app/services/provider_service.py +++ b/app/services/provider_service.py @@ -6,13 +6,10 @@ from collections.abc import AsyncGenerator from typing import Any, ClassVar -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession -# For async support -from app.core.cache import ( - DEBUG_CACHE, - provider_service_cache, -) +from app.core.async_cache import async_provider_service_cache, DEBUG_CACHE from app.core.logger import get_logger from app.core.security import decrypt_api_key, encrypt_api_key from 
app.exceptions.exceptions import InvalidProviderException, BaseInvalidRequestException, InvalidForgeKeyException @@ -53,7 +50,7 @@ class method rather than direct instantiation. # ------------------------------------------------------------------ # Helper for building a cache key that works across all workers. - # Stored via app.core.cache.provider_service_cache which resolves to + # Stored via app.core.async_cache.async_provider_service_cache which resolves to # either RedisCache or in-memory Cache. # ------------------------------------------------------------------ @@ -62,7 +59,7 @@ def _model_cache_key(cls, provider_name: str, cache_key: str) -> str: # Using a stable namespace makes invalidation easier return f"models:{provider_name}:{cache_key}" - def __init__(self, user_id: int, db: Session): + def __init__(self, user_id: int, db: AsyncSession): self.user_id = user_id self.db = db self.provider_keys: dict[str, dict[str, Any]] = {} @@ -71,28 +68,7 @@ def __init__(self, user_id: int, db: Session): self._keys_loaded = False @classmethod - def get_instance(cls, user: User, db: Session) -> "ProviderService": - """Get a cached instance of ProviderService for a user or create a new one""" - cache_key = f"provider_service:{user.id}" - cached_instance = provider_service_cache.get(cache_key) - if cached_instance: - if DEBUG_CACHE: - logger.debug( - f"Using cached ProviderService instance for user {user.id}" - ) - # Update the db session reference for the cached instance - cached_instance.db = db - return cached_instance - - # No cached instance found, create a new one - if DEBUG_CACHE: - logger.debug(f"Creating new ProviderService instance for user {user.id}") - instance = cls(user.id, db) - provider_service_cache.set(cache_key, instance) - return instance - - @classmethod - async def async_get_instance(cls, user: User, db: Session) -> "ProviderService": + async def async_get_instance(cls, user: User, db: AsyncSession) -> "ProviderService": """Get a cached instance of 
ProviderService for a user or create a new one (async version)""" from app.core.async_cache import async_provider_service_cache @@ -117,7 +93,7 @@ async def async_get_instance(cls, user: User, db: Session) -> "ProviderService": return instance @classmethod - def get_cached_models( + async def get_cached_models( cls, provider_name: str, cache_key: str ) -> list[dict[str, Any]] | None: """Return cached model list if present (shared cache).""" @@ -131,7 +107,7 @@ def get_cached_models( return l1_entry[1] # -------- L2: shared cache (Redis / memory) -------- - models = provider_service_cache.get(key) + models = await async_provider_service_cache.get(key) if models: # populate L1 cls._models_l1_cache[key] = (time.time() + cls._models_cache_ttl, models) @@ -140,14 +116,14 @@ def get_cached_models( return models @classmethod - def cache_models( + async def cache_models( cls, provider_name: str, cache_key: str, models: list[dict[str, Any]] ) -> None: """Store models in the shared cache with a TTL.""" key = cls._model_cache_key(provider_name, cache_key) # Write to shared cache (L2) - provider_service_cache.set(key, models, ttl=cls._models_cache_ttl) + await async_provider_service_cache.set(key, models, ttl=cls._models_cache_ttl) # Populate/refresh L1 cls._models_l1_cache[key] = (time.time() + cls._models_cache_ttl, models) @@ -163,30 +139,14 @@ def _get_adapters(self) -> dict[str, ProviderAdapter]: ProviderService._adapters_cache = ProviderAdapterFactory.get_all_adapters() return ProviderService._adapters_cache - def _parse_model_mapping(self, mapping_str: str | None) -> dict: - if not mapping_str: - return {} - try: - return json.loads(mapping_str) - except json.JSONDecodeError: - logger.warning(f"Failed to parse model_mapping JSON: {mapping_str}") - # Try a literal eval as fallback - try: - import ast - - return ast.literal_eval(mapping_str) - except (SyntaxError, ValueError): - logger.warning(f"Could not parse model_mapping: {mapping_str}") - return {} - - def 
_load_provider_keys(self) -> dict[str, dict[str, Any]]: + async def _load_provider_keys(self) -> dict[str, dict[str, Any]]: """Load all provider keys for the user synchronously, with lazy loading and caching.""" if self._keys_loaded: return self.provider_keys # Try to get provider keys from cache cache_key = f"provider_keys:{self.user_id}" - cached_keys = provider_service_cache.get(cache_key) + cached_keys = await async_provider_service_cache.get(cache_key) if cached_keys is not None: if DEBUG_CACHE: logger.debug( @@ -204,13 +164,12 @@ def _load_provider_keys(self) -> dict[str, dict[str, Any]]: # Query ProviderKey directly by user_id from app.models.provider_key import ProviderKey - provider_key_records = ( - self.db.query(ProviderKey).filter(ProviderKey.user_id == self.user_id).all() - ) + result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id)) + provider_key_records = result.scalars().all() keys = {} for provider_key in provider_key_records: - model_mapping = self._parse_model_mapping(provider_key.model_mapping) + model_mapping = provider_key.model_mapping or {} keys[provider_key.provider_name] = { "api_key": decrypt_api_key(provider_key.encrypted_api_key), @@ -226,7 +185,7 @@ def _load_provider_keys(self) -> dict[str, dict[str, Any]]: logger.debug( f"Caching provider keys for user {self.user_id} (TTL: 3600s) (sync)" ) - provider_service_cache.set(cache_key, keys, ttl=3600) # Cache for 1 hour + await async_provider_service_cache.set(cache_key, keys, ttl=3600) # Cache for 1 hour return keys @@ -257,13 +216,12 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]: # Query ProviderKey directly by user_id from app.models.provider_key import ProviderKey - provider_key_records = ( - self.db.query(ProviderKey).filter(ProviderKey.user_id == self.user_id).all() - ) + result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id)) + provider_key_records = result.scalars().all() keys 
= {} for provider_key in provider_key_records: - model_mapping = self._parse_model_mapping(provider_key.model_mapping) + model_mapping = provider_key.model_mapping or {} keys[provider_key.provider_name] = { "api_key": decrypt_api_key(provider_key.encrypted_api_key), @@ -414,7 +372,7 @@ async def list_models( cache_key = f"{base_url}:{hash(frozenset(provider_data.get('model_mapping', {}).items()))}" # Check if we have cached models for this provider - cached_models = self.get_cached_models(provider_name, cache_key) + cached_models = await self.get_cached_models(provider_name, cache_key) if cached_models: models.extend(cached_models) continue @@ -441,7 +399,7 @@ async def _list_models_helper( for model in model_names ] # Cache the results - self.cache_models(provider_name, cache_key, provider_models) + await self.cache_models(provider_name, cache_key, provider_models) return provider_models except Exception as e: @@ -580,12 +538,9 @@ async def process_request( # Record the usage statistics using the new logging method # Use a fresh DB session for logging, since the original request session # may have been closed by FastAPI after the response was returned. 
- from sqlalchemy.orm import Session + from app.core.database import get_db_session - from app.core.database import SessionLocal - - new_db_session: Session = SessionLocal() - try: + async with get_db_session() as new_db_session: await UsageStatsService.log_api_request( db=new_db_session, user_id=self.user_id, @@ -595,9 +550,7 @@ async def process_request( input_tokens=input_tokens, output_tokens=output_tokens, ) - finally: - new_db_session.close() - return result + return result else: # For streaming responses, wrap the generator to count tokens async def token_counting_stream() -> AsyncGenerator[bytes, None]: @@ -692,12 +645,9 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]: # Use a fresh DB session for logging, since the original request session # may have been closed by FastAPI after the response was returned. - from sqlalchemy.orm import Session - - from app.core.database import SessionLocal + from app.core.database import get_db_session - new_db_session: Session = SessionLocal() - try: + async with get_db_session() as new_db_session: await UsageStatsService.log_api_request( db=new_db_session, user_id=self.user_id, @@ -707,14 +657,12 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]: input_tokens=input_tokens, output_tokens=output_tokens, ) - finally: - new_db_session.close() # End of token_counting_stream function return token_counting_stream() -def create_default_tensorblock_provider_for_user(user_id: int, db: Session) -> None: +async def create_default_tensorblock_provider_for_user(user_id: int, db: AsyncSession) -> None: """ Create a default TensorBlock provider key for a new user. This allows users to use Forge immediately without binding their own API keys. 
@@ -756,12 +704,12 @@ def create_default_tensorblock_provider_for_user(user_id: int, db: Session) -> N ) db.add(provider_key) - db.commit() + await db.commit() logger.info(f"Created default TensorBlock provider for user {user_id}") except Exception as e: - db.rollback() + await db.rollback() logger.error( "Error creating default TensorBlock provider for user {}: {}", user_id, diff --git a/app/services/usage_stats_service.py b/app/services/usage_stats_service.py index 4aec1f2..807f32a 100644 --- a/app/services/usage_stats_service.py +++ b/app/services/usage_stats_service.py @@ -2,7 +2,7 @@ from typing import Any from sqlalchemy import func, select -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from app.core.logger import get_logger from app.models.api_request_log import ApiRequestLog @@ -15,7 +15,7 @@ class UsageStatsService: @staticmethod async def log_api_request( - db: Session, + db: AsyncSession, user_id: int | None, provider_name: str, model: str, @@ -39,19 +39,19 @@ async def log_api_request( cost=cost, ) db.add(log_entry) - db.commit() + await db.commit() logger.debug( f"Logged API request for user {user_id}: {provider_name}/{model}/{endpoint}" ) except Exception as e: - db.rollback() + await db.rollback() logger.error( f"Failed to log API request for user {user_id}: {e}", exc_info=True ) @staticmethod - def get_user_stats( - db: Session, + async def get_user_stats( + db: AsyncSession, user_id: int, provider: str | None = None, model: str | None = None, @@ -89,7 +89,7 @@ def get_user_stats( query = query.group_by(ApiRequestLog.provider_name, ApiRequestLog.model) - results = db.execute(query).fetchall() + results = await db.execute(query) return [ { @@ -105,8 +105,8 @@ def get_user_stats( ] @staticmethod - def get_all_stats( - db: Session, + async def get_all_stats( + db: AsyncSession, provider: str | None = None, # Add provider filter model: str | None = None, # Add model filter start_date: date | None = None, # Add 
start_date filter @@ -148,7 +148,7 @@ def get_all_stats( query = query.group_by(ApiRequestLog.provider_name, ApiRequestLog.model) # Execute query - results = db.execute(query).fetchall() + results = await db.execute(query) # Convert results to dictionaries return [ diff --git a/forge-cli.py b/forge-cli.py index 1947449..05f1c1b 100755 --- a/forge-cli.py +++ b/forge-cli.py @@ -219,6 +219,38 @@ def list_provider_keys(self) -> list[dict[str, Any]]: except Exception as e: print(f"❌ Error listing provider keys: {str(e)}") return [] + + def update_provider_key(self, provider_name: str, api_key: str | None = None, base_url: str | None = None, model_mapping: str | None = None, config: str | None = None) -> bool: + """Update a provider key""" + if not self.token: + print("❌ Not authenticated. Please login first.") + return False + + url = f"{self.api_url}/provider-keys/{provider_name}" + headers = { + "Authorization": f"Bearer {self.token}", + "Content-Type": "application/json", + } + data = { + "api_key": api_key, + "base_url": base_url, + "model_mapping": json.loads(model_mapping) if model_mapping else None, + "config": json.loads(config) if config else None, + } + try: + response = requests.put(url, headers=headers, json=data) + + if response.status_code == HTTPStatus.OK: + print(f"✅ Successfully updated {provider_name} API key!") + return True + else: + print(f"❌ Error updated provider key: {response.status_code}") + print(response.text) + return False + except Exception as e: + print(f"❌ Error updating provider key: {str(e)}") + return False + def delete_provider_key(self, provider_name: str) -> bool: """Delete a provider key""" @@ -233,7 +265,7 @@ def delete_provider_key(self, provider_name: str) -> bool: response = requests.delete(url, headers=headers) if response.status_code == HTTPStatus.OK: - print(f"✅ Successfully deleted {provider_name} API key!") + print(f"✅ Successfully deleted provider key {provider_name}!") return True else: print(f"❌ Error deleting provider 
key: {response.status_code}") @@ -462,8 +494,9 @@ def main(): print("8. Add Provider Key") print("9. List Provider Keys") print("10. Delete Provider Key") - print("11. Test Chat Completion") - print("12. List Models") + print("11. Update Provider Key") + print("12. Test Chat Completion") + print("13. List Models") print("0. Exit") choice = input("\nEnter your choice (0-12): ") @@ -614,7 +647,7 @@ def main(): key = getpass("Enter provider API key: ") base_url = input("Enter provider base URL (optional, press Enter to skip): ") config = input("Enter provider config in json string format (optional, press Enter to skip): ") - model_mapping = input("Enter model ampping config in json string format (optional, press Enter to skip): ") + model_mapping = input("Enter model mapping config in json string format (optional, press Enter to skip): ") forge.add_provider_key(provider, key, base_url=base_url, config=config, model_mapping=model_mapping) elif choice == "9": @@ -624,10 +657,24 @@ def main(): forge.list_provider_keys() elif choice == "10": - provider = input("Enter provider name to delete: ") - forge.delete_provider_key(provider) + if not forge.token: + token = input("Enter JWT token: ") + forge.token = token + provider_name = input("Enter provider name to delete: ") + forge.delete_provider_key(provider_name) elif choice == "11": + if not forge.token: + token = input("Enter JWT token: ") + forge.token = token + provider_name = input("Enter provider name to update: ") + api_key = getpass("Enter provider API key: ") + base_url = input("Enter provider base URL (optional, press Enter to skip): ") + config = input("Enter provider config in json string format (optional, press Enter to skip): ") + model_mapping = input("Enter model mapping config in json string format (optional, press Enter to skip): ") + forge.update_provider_key(provider_name, api_key, base_url=base_url, config=config, model_mapping=model_mapping) + + elif choice == "12": model = input("Enter model ID: ") 
message = input("Enter message: ") api_key = input("Enter your Forge API key: ").strip() @@ -642,7 +689,7 @@ def main(): continue forge.test_chat_completion(model, message, api_key) - elif choice == "12": + elif choice == "13": api_key = input( "Enter your Forge API key (or press Enter to use stored key if available): " ).strip() diff --git a/pyproject.toml b/pyproject.toml index 0471c82..5075668 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "python-jose>=3.3.0", "passlib>=1.7.4", "python-multipart>=0.0.5", - "sqlalchemy>=2.0.0", + "sqlalchemy[asyncio]>=2.0.0", "alembic>=1.10.4", "aiohttp>=3.8.4", "cryptography>=40.0.0", diff --git a/tests/cache/test_async_cache.py b/tests/cache/test_async_cache.py index 115693e..ac32cd7 100755 --- a/tests/cache/test_async_cache.py +++ b/tests/cache/test_async_cache.py @@ -11,6 +11,8 @@ from pathlib import Path from unittest.mock import MagicMock +from sqlalchemy import delete + from app.core.async_cache import ( AsyncCache, async_cached, @@ -188,7 +190,7 @@ async def test_model_list_async_cache(): ProviderService.cache_models(provider_name, cache_key, mock_models) # Test retrieving from cache - cached_models = ProviderService.get_cached_models(provider_name, cache_key) + cached_models = await ProviderService.get_cached_models(provider_name, cache_key) assert cached_models is not None, "Async model list cache get failed" assert len(cached_models) == EXPECTED_MODEL_COUNT, "Model list length mismatch" assert cached_models[FIRST_MODEL_INDEX]["id"] == "gpt-4", "Model ID mismatch" @@ -197,7 +199,7 @@ async def test_model_list_async_cache(): # Test cache invalidation ProviderService._models_cache = {} ProviderService._models_cache_expiry = {} - cached_models = ProviderService.get_cached_models(provider_name, cache_key) + cached_models = await ProviderService.get_cached_models(provider_name, cache_key) assert cached_models is None, "Async model list cache invalidation failed" print("✅ Async model list 
cache test passed") @@ -354,14 +356,14 @@ async def test_async_cache_invalidation(): model_cache_key = "default" # Set model list in cache - ProviderService.cache_models(provider_name, model_cache_key, mock_models) - cached_models = ProviderService.get_cached_models(provider_name, model_cache_key) + await ProviderService.cache_models(provider_name, model_cache_key, mock_models) + cached_models = await ProviderService.get_cached_models(provider_name, model_cache_key) assert cached_models is not None, "Model list cache set failed" # Invalidate model list cache ProviderService._models_cache = {} ProviderService._models_cache_expiry = {} - cached_models = ProviderService.get_cached_models(provider_name, model_cache_key) + cached_models = await ProviderService.get_cached_models(provider_name, model_cache_key) assert cached_models is None, "Model list cache invalidation failed" # Test 4: Provider service instance cache invalidation @@ -511,55 +513,55 @@ async def test_async_cache_warming(): from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker + from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession + from app.models.base import Base # Use SQLite in-memory database for testing - engine = create_engine("sqlite:///:memory:") + engine = create_async_engine("sqlite:///:memory:") # Create all tables Base.metadata.create_all(bind=engine) - testing = sessionmaker(autocommit=False, autoflush=False, bind=engine) - db = testing() - - try: - # Create test user and API key - user = User( - email="test@example.com", - username="testuser", - is_active=True, - hashed_password="dummy_hash", - ) - db.add(user) - db.commit() - - # Create a test API key - test_api_key = "test_key_123" - encrypted_key = encrypt_api_key(test_api_key) - - provider_key = ProviderKey( - user_id=user.id, - provider_name="test_provider", - encrypted_api_key=encrypted_key, - ) - db.add(provider_key) - db.commit() - - # Warm the cache - await 
warm_cache_async(db) - - # Verify user is cached with the correct API key - cached_user = await get_cached_user_async(test_api_key) - assert cached_user is not None - assert cached_user.id == user.id - - finally: - # Clean up - db.query(ProviderKey).delete() - db.query(User).delete() - db.commit() - db.close() - # Drop all tables - Base.metadata.drop_all(bind=engine) + testing = async_sessionmaker(autocommit=False, autoflush=False, bind=engine) + async with testing() as db: + try: + # Create test user and API key + user = User( + email="test@example.com", + username="testuser", + is_active=True, + hashed_password="dummy_hash", + ) + db.add(user) + db.commit() + + # Create a test API key + test_api_key = "test_key_123" + encrypted_key = encrypt_api_key(test_api_key) + + provider_key = ProviderKey( + user_id=user.id, + provider_name="test_provider", + encrypted_api_key=encrypted_key, + ) + db.add(provider_key) + db.commit() + + # Warm the cache + await warm_cache_async(db) + + # Verify user is cached with the correct API key + cached_user = await get_cached_user_async(test_api_key) + assert cached_user is not None + assert cached_user.id == user.id + + finally: + # Clean up + await db.execute(delete(ProviderKey)) + await db.execute(delete(User)) + await db.commit() + # Drop all tables + Base.metadata.drop_all(bind=engine) return True diff --git a/tests/cache/test_sync_cache.py b/tests/cache/test_sync_cache.py deleted file mode 100755 index a9a696d..0000000 --- a/tests/cache/test_sync_cache.py +++ /dev/null @@ -1,374 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script to verify all types of caching in the system. -Tests user cache, provider service cache, provider keys cache, and model list cache. 
-""" - -import os -import sys -import time -from pathlib import Path -from unittest.mock import MagicMock - -from app.core.cache import ( - Cache, - invalidate_provider_models_cache, - invalidate_user_cache_by_id, - monitor_cache_performance, - provider_service_cache, - user_cache, - warm_cache, -) -from app.models.user import User -from app.services.provider_service import ProviderService - -# Add constants at the top of the file, after imports -EXPECTED_MODEL_COUNT = 2 # Expected number of models in test data -FIRST_MODEL_INDEX = 0 # Index of first model in test data -SECOND_MODEL_INDEX = 1 # Index of second model in test data - -# Add the project root to the Python path -script_dir = Path(__file__).resolve().parent.parent.parent -sys.path.insert(0, str(script_dir)) - -# Change to the project root directory -os.chdir(script_dir) - -# Set environment variables before importing cache modules -os.environ["FORCE_MEMORY_CACHE"] = "true" -os.environ["DEBUG_CACHE"] = "true" - - -# Clear caches before each test -def setup_function(function): - user_cache.clear() - provider_service_cache.clear() - ProviderService._models_cache = {} - ProviderService._models_cache_expiry = {} - ProviderService._models_l1_cache = {} - - -def test_basic_cache_operations(): - """Test basic cache operations""" - print("\n🔍 TESTING BASIC CACHE OPERATIONS") - print("================================") - - # Create a new cache instance - cache = Cache(ttl_seconds=5) - - # Test set and get - cache.set("test_key", "test_value") - value = cache.get("test_key") - assert value == "test_value", "Cache get/set failed" - - # Test TTL - cache.set("expiring_key", "expiring_value", ttl=1) - time.sleep(1.1) # Wait for TTL to expire - value = cache.get("expiring_key") - assert value is None, "TTL expiration failed" - - # Test delete - cache.set("delete_key", "delete_value") - cache.delete("delete_key") - value = cache.get("delete_key") - assert value is None, "Cache delete failed" - - # Test clear - 
cache.set("clear_key", "clear_value") - cache.clear() - value = cache.get("clear_key") - assert value is None, "Cache clear failed" - - print("✅ Basic cache operations test passed") - - -def test_user_cache(): - """Test user caching functionality""" - print("\n🔍 TESTING USER CACHE") - print("====================") - - # Create mock user - mock_user = User( - id=1, email="test@example.com", username="testuser", is_active=True - ) - - # Test caching user - api_key = "test_api_key_123" - user_cache.set(f"user:{api_key}", mock_user) - - # Test retrieving from cache - cached_user = user_cache.get(f"user:{api_key}") - assert cached_user is not None, "User cache get failed" - assert cached_user.id == mock_user.id, "Cached user ID mismatch" - assert cached_user.email == mock_user.email, "Cached user email mismatch" - - # Test cache invalidation - user_cache.delete(f"user:{api_key}") - cached_user = user_cache.get(f"user:{api_key}") - assert cached_user is None, "User cache invalidation failed" - - print("✅ User cache test passed") - - -def test_provider_keys_cache(): - """Test provider keys caching functionality""" - print("\n🔍 TESTING PROVIDER KEYS CACHE") - print("============================") - - # Create mock provider keys - mock_keys = { - "openai": { - "api_key": "sk_test_123", - "base_url": "https://api.openai.com/v1", - "model_mapping": {"gpt-4": "gpt-4-turbo"}, - }, - "anthropic": { - "api_key": "sk-ant-test-123", - "base_url": "https://api.anthropic.com/v1", - "model_mapping": {}, - }, - } - - # Test caching provider keys - user_id = 1 - cache_key = f"provider_keys:{user_id}" - provider_service_cache.set(cache_key, mock_keys, ttl=3600) - - # Test retrieving from cache - cached_keys = provider_service_cache.get(cache_key) - assert cached_keys is not None, "Provider keys cache get failed" - assert "openai" in cached_keys, "OpenAI provider key missing from cache" - assert "anthropic" in cached_keys, "Anthropic provider key missing from cache" - assert ( - 
cached_keys["openai"]["model_mapping"]["gpt-4"] == "gpt-4-turbo" - ), "Model mapping mismatch" - - # Test cache invalidation - provider_service_cache.delete(cache_key) - cached_keys = provider_service_cache.get(cache_key) - assert cached_keys is None, "Provider keys cache invalidation failed" - - print("✅ Provider keys cache test passed") - - -def test_model_list_cache(): - """Test model list caching functionality""" - print("\n🔍 TESTING MODEL LIST CACHE") - print("=========================") - - # Create mock model list - mock_models = [ - { - "id": "gpt-4", - "display_name": "GPT-4", - "object": "model", - "owned_by": "openai", - }, - { - "id": "claude-3", - "display_name": "Claude 3", - "object": "model", - "owned_by": "anthropic", - }, - ] - - # Test caching models - provider_name = "openai" - cache_key = "default" - ProviderService.cache_models(provider_name, cache_key, mock_models) - - # Test retrieving from cache - cached_models = ProviderService.get_cached_models(provider_name, cache_key) - assert cached_models is not None, "Model list cache get failed" - assert len(cached_models) == EXPECTED_MODEL_COUNT, "Model list length mismatch" - assert cached_models[FIRST_MODEL_INDEX]["id"] == "gpt-4", "Model ID mismatch" - assert cached_models[SECOND_MODEL_INDEX]["id"] == "claude-3", "Model ID mismatch" - - # Test cache invalidation - invalidate_provider_models_cache(provider_name) - cached_models = ProviderService.get_cached_models(provider_name, cache_key) - assert cached_models is None, "Model list cache invalidation failed" - - print("✅ Model list cache test passed") - - -def test_cache_invalidation(): - """Test cache invalidation scenarios""" - print("\n🔍 TESTING CACHE INVALIDATION") - print("===========================") - - # Test 1: User cache invalidation - print("\n🔄 Test 1: User cache invalidation") - mock_user = User( - id=1, email="test@example.com", username="testuser", is_active=True - ) - api_key = "test_api_key_123" - - # Set user in cache - 
user_cache.set(f"user:{api_key}", mock_user) - assert user_cache.get(f"user:{api_key}") is not None, "User cache set failed" - - # Invalidate user cache - user_cache.delete(f"user:{api_key}") - assert user_cache.get(f"user:{api_key}") is None, "User cache invalidation failed" - - # Test 2: Provider keys cache invalidation - print("\n🔄 Test 2: Provider keys cache invalidation") - mock_keys = { - "openai": { - "api_key": "sk_test_123", - "base_url": "https://api.openai.com/v1", - "model_mapping": {"gpt-4": "gpt-4-turbo"}, - } - } - user_id = 1 - cache_key = f"provider_keys:{user_id}" - - # Set provider keys in cache - provider_service_cache.set(cache_key, mock_keys, ttl=3600) - assert ( - provider_service_cache.get(cache_key) is not None - ), "Provider keys cache set failed" - - # Invalidate provider keys cache - provider_service_cache.delete(cache_key) - assert ( - provider_service_cache.get(cache_key) is None - ), "Provider keys cache invalidation failed" - - # Test 3: Model list cache invalidation - print("\n🔄 Test 3: Model list cache invalidation") - mock_models = [ - { - "id": "gpt-4", - "display_name": "GPT-4", - "object": "model", - "owned_by": "openai", - } - ] - provider_name = "openai" - model_cache_key = "default" - - # Set model list in cache - ProviderService.cache_models(provider_name, model_cache_key, mock_models) - assert ( - ProviderService.get_cached_models(provider_name, model_cache_key) is not None - ), "Model list cache set failed" - - # Invalidate model list cache - invalidate_provider_models_cache(provider_name) - assert ( - ProviderService.get_cached_models(provider_name, model_cache_key) is None - ), "Model list cache invalidation failed" - - -def test_cache_invalidation_by_id(): - """Test cache invalidation by user ID""" - print("\nTesting cache invalidation by user ID...") - - # Create test user - user = User( - id=1, - email="test@example.com", - username="testuser", - is_active=True, - hashed_password="dummy_hash", - ) - - # Create test 
API keys - api_key1 = "test_key_1" - api_key2 = "test_key_2" - - # Cache user with multiple API keys - user_cache.set(f"user:{api_key1}", user) - user_cache.set(f"user:{api_key2}", user) - - # Verify user is cached - cached_user1 = user_cache.get(f"user:{api_key1}") - cached_user2 = user_cache.get(f"user:{api_key2}") - assert cached_user1 is not None - assert cached_user2 is not None - assert cached_user1.id == user.id - assert cached_user2.id == user.id - - # Invalidate all cache entries for this user - invalidate_user_cache_by_id(user.id) - - # Verify cache is cleared - assert user_cache.get(f"user:{api_key1}") is None - assert user_cache.get(f"user:{api_key2}") is None - - -def test_provider_models_cache_invalidation(): - """Test provider models cache invalidation""" - print("\nTesting provider models cache invalidation...") - - # Set up test data - provider_name = "test_provider" - models = [{"id": "model1"}, {"id": "model2"}] - cache_key = "default" - - # Cache models using the public API - ProviderService.cache_models(provider_name, cache_key, models) - - # Verify models are cached - cached_models = ProviderService.get_cached_models(provider_name, cache_key) - assert cached_models is not None - assert len(cached_models) == 2 - assert cached_models[0]["id"] == "model1" - assert cached_models[1]["id"] == "model2" - - # Invalidate cache - invalidate_provider_models_cache(provider_name) - - # Verify cache is cleared - cached_models = ProviderService.get_cached_models(provider_name, cache_key) - assert cached_models is None, "Provider models cache invalidation failed" - - -def test_cache_stats_and_monitoring(): - """Test cache statistics and monitoring""" - print("\n🔍 TESTING CACHE STATS AND MONITORING") - print("===================================") - - # Test basic cache operations to generate stats - cache = Cache(ttl_seconds=5) - cache.set("test_key", "test_value") - cache.get("test_key") # Hit - cache.get("nonexistent") # Miss - - # Test stats - stats = 
cache.stats() - assert stats["hits"] == 1, "Cache hit count mismatch" - assert stats["misses"] == 1, "Cache miss count mismatch" - assert stats["total"] == 2, "Cache total count mismatch" - assert stats["hit_rate"] == 0.5, "Cache hit rate mismatch" - assert stats["entries"] == 1, "Cache entries count mismatch" - - # Test monitoring - monitoring = monitor_cache_performance() - assert "stats" in monitoring, "Cache stats missing" - assert "overall_hit_rate" in monitoring, "Overall hit rate missing" - assert "issues" in monitoring, "Issues list missing" - - print("✅ Cache stats and monitoring test passed") - - -async def test_cache_warming(): - """Test cache warming functionality""" - print("\n🔍 TESTING CACHE WARMING") - print("=======================") - - # Mock database session - db = MagicMock() - - # Test cache warming - await warm_cache(db) - - # Verify cache is populated - assert user_cache.stats()["entries"] > 0, "User cache not warmed" - assert ( - provider_service_cache.stats()["entries"] > 0 - ), "Provider service cache not warmed" - - print("✅ Cache warming test passed") diff --git a/tests/mock_testing/add_mock_provider.py b/tests/mock_testing/add_mock_provider.py index 6dc6c8f..795ac91 100755 --- a/tests/mock_testing/add_mock_provider.py +++ b/tests/mock_testing/add_mock_provider.py @@ -4,6 +4,7 @@ This allows users to test the Forge middleware without needing real API keys. 
""" +import asyncio import argparse import json import os @@ -13,10 +14,11 @@ from dotenv import load_dotenv from app.core.cache import provider_service_cache -from app.core.database import SessionLocal +from app.core.database import get_db_session from app.core.security import encrypt_api_key from app.models.provider_key import ProviderKey from app.models.user import User +from sqlalchemy import select # Add the project root to the Python path script_dir = Path(__file__).resolve().parent.parent.parent @@ -29,91 +31,86 @@ os.chdir(script_dir) -def setup_mock_provider(username: str, force: bool = False): +async def setup_mock_provider(username: str, force: bool = False): """Add a mock provider key to the specified user account""" # Create a database session - db = SessionLocal() + async with get_db_session() as db: + try: + # Find the user + result = await db.execute(select(User).filter(User.username == username)) + user = result.scalar_one_or_none() + if not user: + print(f"❌ User '{username}' not found. 
Please provide a valid username.") + return False + + # Check if the mock provider already exists for this user + result = await db.execute(select(ProviderKey).filter(ProviderKey.user_id == user.id, ProviderKey.provider_name == "mock")) + existing_provider = result.scalar_one_or_none() + + if existing_provider and not force: + print(f"⚠️ Mock provider already exists for user '{username}'.") + print("Use --force to replace it.") + return False + + # If force is set and provider exists, delete the existing one + if existing_provider and force: + await db.delete(existing_provider) + await db.commit() + print(f"🗑️ Deleted existing mock provider for user '{username}'.") + + # Create a mock API key - it doesn't need to be secure as it's not used + mock_api_key = "mock-api-key-for-testing-purposes" + encrypted_key = encrypt_api_key(mock_api_key) + + # Create model mappings for common models to their mock equivalents + model_mapping = { + "mock-only-gpt-3.5-turbo": "mock-gpt-3.5-turbo", + "mock-only-gpt-4": "mock-gpt-4", + "mock-only-gpt-4o": "mock-gpt-4o", + "mock-only-claude-3-opus": "mock-claude-3-opus", + "mock-only-claude-3-sonnet": "mock-claude-3-sonnet", + "mock-only-claude-3-haiku": "mock-claude-3-haiku", + } + + # Create the provider key + provider_key = ProviderKey( + user_id=user.id, + provider_name="mock", + encrypted_api_key=encrypted_key, + model_mapping=( + model_mapping + ), # JSON column: store the dict directly (no json.dumps) + ) + + db.add(provider_key) + await db.commit() - try: - # Find the user - user = db.query(User).filter(User.username == username).first() - if not user: - print(f"❌ User '{username}' not found.
Please provide a valid username.") + # Invalidate provider key cache for this user to force refresh + provider_service_cache.delete(f"provider_keys:{user.id}") + print(f"✅ Invalidated provider key cache for user '{username}'") + + print(f"✅ Successfully added mock provider for user '{username}'.") + print( + f"🔑 Mock API Key: {mock_api_key} (not a real key, used for testing only)" + ) + print("") + print("You can now use the following models with this provider:") + for original, mock in model_mapping.items(): + print(f" - {original} -> {mock}") + print("") + print( + "Use these models with your Forge API Key to test the middleware without real API calls." + ) + + return True + + except Exception as e: + await db.rollback() + print(f"❌ Error setting up mock provider: {str(e)}") return False - # Check if the mock provider already exists for this user - existing_provider = ( - db.query(ProviderKey) - .filter(ProviderKey.user_id == user.id, ProviderKey.provider_name == "mock") - .first() - ) - - if existing_provider and not force: - print(f"⚠️ Mock provider already exists for user '{username}'.") - print("Use --force to replace it.") - return False - # If force is set and provider exists, delete the existing one - if existing_provider and force: - db.delete(existing_provider) - db.commit() - print(f"🗑️ Deleted existing mock provider for user '{username}'.") - - # Create a mock API key - it doesn't need to be secure as it's not used - mock_api_key = "mock-api-key-for-testing-purposes" - encrypted_key = encrypt_api_key(mock_api_key) - - # Create model mappings for common models to their mock equivalents - model_mapping = { - "mock-only-gpt-3.5-turbo": "mock-gpt-3.5-turbo", - "mock-only-gpt-4": "mock-gpt-4", - "mock-only-gpt-4o": "mock-gpt-4o", - "mock-only-claude-3-opus": "mock-claude-3-opus", - "mock-only-claude-3-sonnet": "mock-claude-3-sonnet", - "mock-only-claude-3-haiku": "mock-claude-3-haiku", - } - - # Create the provider key - provider_key = ProviderKey( - 
user_id=user.id, - provider_name="mock", - encrypted_api_key=encrypted_key, - model_mapping=json.dumps( - model_mapping - ), # Use json.dumps for proper storage - ) - - db.add(provider_key) - db.commit() - - # Invalidate provider key cache for this user to force refresh - provider_service_cache.delete(f"provider_keys:{user.id}") - print(f"✅ Invalidated provider key cache for user '{username}'") - - print(f"✅ Successfully added mock provider for user '{username}'.") - print( - f"🔑 Mock API Key: {mock_api_key} (not a real key, used for testing only)" - ) - print("") - print("You can now use the following models with this provider:") - for original, mock in model_mapping.items(): - print(f" - {original} -> {mock}") - print("") - print( - "Use these models with your Forge API Key to test the middleware without real API calls." - ) - - return True - - except Exception as e: - db.rollback() - print(f"❌ Error setting up mock provider: {str(e)}") - return False - finally: - db.close() - - -def main(): +async def main(): parser = argparse.ArgumentParser( description="Add a mock provider to a user account for testing" ) @@ -136,4 +133,4 @@ def main(): if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/tests/mock_testing/setup_test_user.py b/tests/mock_testing/setup_test_user.py index b90811a..defe1df 100644 --- a/tests/mock_testing/setup_test_user.py +++ b/tests/mock_testing/setup_test_user.py @@ -4,6 +4,7 @@ This script is used to prepare the environment for testing the mock provider. 
""" +import asyncio import json import os import sys @@ -11,9 +12,10 @@ from dotenv import load_dotenv from passlib.context import CryptContext +from sqlalchemy import select from app.core.cache import invalidate_user_cache, provider_service_cache -from app.core.database import SessionLocal +from app.core.database import get_db_session from app.core.security import encrypt_api_key from app.models.provider_key import ProviderKey from app.models.user import User @@ -37,130 +39,123 @@ MOCK_PROVIDER_API_KEY = "mock-api-key-for-testing-purposes" -def create_or_update_test_user(): +async def create_or_update_test_user(): """Create a test user with a known Forge API key or update existing user""" - db = SessionLocal() - - try: - # Try to find user by username first - user = db.query(User).filter(User.username == TEST_USERNAME).first() - - # If not found by username, try by email - if not user: - user = db.query(User).filter(User.email == TEST_EMAIL).first() - - # If user exists, update the forge API key - if user: - print(f"✅ Found existing user: {user.username} (email: {user.email})") - old_key = user.forge_api_key - user.forge_api_key = TEST_FORGE_API_KEY - db.commit() - db.refresh(user) - # Invalidate the user in cache to force refresh with new API key - invalidate_user_cache(old_key) - invalidate_user_cache(TEST_FORGE_API_KEY) - print( - f"✅ Invalidated user cache for API keys: {old_key} and {TEST_FORGE_API_KEY}" + async with get_db_session() as db: + try: + # Try to find user by username first + result = await db.execute(select(User).filter(User.username == TEST_USERNAME)) + user = result.scalar_one_or_none() + + # If not found by username, try by email + if not user: + result = await db.execute(select(User).filter(User.email == TEST_EMAIL)) + user = result.scalar_one_or_none() + + # If user exists, update the forge API key + if user: + print(f"✅ Found existing user: {user.username} (email: {user.email})") + old_key = user.forge_api_key + user.forge_api_key = 
TEST_FORGE_API_KEY + await db.commit() + await db.refresh(user) + # Invalidate the user in cache to force refresh with new API key + invalidate_user_cache(old_key) + invalidate_user_cache(TEST_FORGE_API_KEY) + print( + f"✅ Invalidated user cache for API keys: {old_key} and {TEST_FORGE_API_KEY}" + ) + print(f"🔄 Updated Forge API Key: {old_key} -> {user.forge_api_key}") + return user + + # Create new user if not exists + hashed_password = pwd_context.hash(TEST_PASSWORD) + user = User( + username=TEST_USERNAME, + email=TEST_EMAIL, + hashed_password=hashed_password, + forge_api_key=TEST_FORGE_API_KEY, + is_active=True, ) - print(f"🔄 Updated Forge API Key: {old_key} -> {user.forge_api_key}") + db.add(user) + await db.commit() + await db.refresh(user) + print(f"✅ Created test user '{TEST_USERNAME}'") + print(f"🔑 Forge API Key: {TEST_FORGE_API_KEY}") return user - # Create new user if not exists - hashed_password = pwd_context.hash(TEST_PASSWORD) - user = User( - username=TEST_USERNAME, - email=TEST_EMAIL, - hashed_password=hashed_password, - forge_api_key=TEST_FORGE_API_KEY, - is_active=True, - ) - db.add(user) - db.commit() - db.refresh(user) - print(f"✅ Created test user '{TEST_USERNAME}'") - print(f"🔑 Forge API Key: {TEST_FORGE_API_KEY}") - return user - - except Exception as e: - db.rollback() - print(f"❌ Error creating/updating test user: {str(e)}") - return None - finally: - db.close() - - -def add_mock_provider_to_user(user_id): - """Add a mock provider to the test user""" - db = SessionLocal() - - try: - # Check if the mock provider already exists for this user - existing_provider = ( - db.query(ProviderKey) - .filter(ProviderKey.user_id == user_id, ProviderKey.provider_name == "mock") - .first() - ) + except Exception as e: + await db.rollback() + print(f"❌ Error creating/updating test user: {str(e)}") + return None - if existing_provider: - print("✅ Mock provider already exists for the test user.") - return True - # Create a mock API key - encrypted_key = 
encrypt_api_key(MOCK_PROVIDER_API_KEY) - - # Create model mappings for common models to their mock equivalents - model_mapping = { - "mock-only-gpt-3.5-turbo": "mock-gpt-3.5-turbo", - "mock-only-gpt-4": "mock-gpt-4", - "mock-only-gpt-4o": "mock-gpt-4o", - "mock-only-claude-3-opus": "mock-claude-3-opus", - "mock-only-claude-3-sonnet": "mock-claude-3-sonnet", - "mock-only-claude-3-haiku": "mock-claude-3-haiku", - } - - # Create the provider key - provider_key = ProviderKey( - user_id=user_id, - provider_name="mock", - encrypted_api_key=encrypted_key, - model_mapping=json.dumps(model_mapping), - ) +async def add_mock_provider_to_user(user_id): + """Add a mock provider to the test user""" + async with get_db_session() as db: + try: + # Check if the mock provider already exists for this user + result = await db.execute(select(ProviderKey).filter(ProviderKey.user_id == user_id, ProviderKey.provider_name == "mock")) + existing_provider = result.scalar_one_or_none() + + if existing_provider: + print("✅ Mock provider already exists for the test user.") + return True + + # Create a mock API key + encrypted_key = encrypt_api_key(MOCK_PROVIDER_API_KEY) + + # Create model mappings for common models to their mock equivalents + model_mapping = { + "mock-only-gpt-3.5-turbo": "mock-gpt-3.5-turbo", + "mock-only-gpt-4": "mock-gpt-4", + "mock-only-gpt-4o": "mock-gpt-4o", + "mock-only-claude-3-opus": "mock-claude-3-opus", + "mock-only-claude-3-sonnet": "mock-claude-3-sonnet", + "mock-only-claude-3-haiku": "mock-claude-3-haiku", + } + + # Create the provider key + provider_key = ProviderKey( + user_id=user_id, + provider_name="mock", + encrypted_api_key=encrypted_key, + model_mapping=model_mapping, + ) - db.add(provider_key) - db.commit() + db.add(provider_key) + await db.commit() - # 
Invalidate provider key cache for this user to force refresh + provider_service_cache.delete(f"provider_keys:{user_id}") + print(f"✅ Invalidated provider key cache for user ID: {user_id}") - print("✅ Successfully added mock provider for test user.") - print(f"🔑 Mock API Key: {MOCK_PROVIDER_API_KEY} (used for testing only)") - print("") - print("You can now use the following models with this provider:") - for original, mock in model_mapping.items(): - print(f" - {original} -> {mock}") + print("✅ Successfully added mock provider for test user.") + print(f"🔑 Mock API Key: {MOCK_PROVIDER_API_KEY} (used for testing only)") + print("") + print("You can now use the following models with this provider:") + for original, mock in model_mapping.items(): + print(f" - {original} -> {mock}") - return True + return True - except Exception as e: - db.rollback() - print(f"❌ Error setting up mock provider: {str(e)}") - return False - finally: - db.close() + except Exception as e: + await db.rollback() + print(f"❌ Error setting up mock provider: {str(e)}") + return False -def main(): +async def main(): """Set up a test user with a mock provider""" print("🔄 Setting up test user with mock provider...") # Create or update test user - user = create_or_update_test_user() + user = await create_or_update_test_user() if not user: sys.exit(1) # Add mock provider to user - if add_mock_provider_to_user(user.id): + if await add_mock_provider_to_user(user.id): print("") print("✅ Setup complete!") print("📝 To test the mock provider, run:") @@ -173,4 +168,4 @@ def main(): if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/tests/unit_tests/test_provider_service.py b/tests/unit_tests/test_provider_service.py index 5ee8c58..51f8c6d 100644 --- a/tests/unit_tests/test_provider_service.py +++ b/tests/unit_tests/test_provider_service.py @@ -22,13 +22,13 @@ from app.models.user import User from app.services.provider_service import ProviderService from 
app.services.providers.adapter_factory import ProviderAdapterFactory -from app.core.cache import provider_service_cache, user_cache +from app.core.async_cache import async_provider_service_cache, async_user_cache class TestProviderService(TestCase): """Test the provider service""" - def setUp(self): + async def asyncSetUp(self): # Reset the adapters cache ProviderService._adapters_cache = {} @@ -38,63 +38,55 @@ def setUp(self): self.provider_key_openai.provider_name = "openai" self.provider_key_openai.encrypted_api_key = "encrypted_openai_key" self.provider_key_openai.base_url = None - self.provider_key_openai.model_mapping = json.dumps({"custom-gpt": "gpt-4"}) + self.provider_key_openai.model_mapping = {"custom-gpt": "gpt-4"} self.provider_key_anthropic = MagicMock(spec=ProviderKey) self.provider_key_anthropic.provider_name = "anthropic" self.provider_key_anthropic.encrypted_api_key = "encrypted_anthropic_key" self.provider_key_anthropic.base_url = None - self.provider_key_anthropic.model_mapping = "{}" + self.provider_key_anthropic.model_mapping = {} self.provider_key_google = MagicMock(spec=ProviderKey) self.provider_key_google.provider_name = "gemini" self.provider_key_google.encrypted_api_key = "encrypted_gemini_key" self.provider_key_google.base_url = None - self.provider_key_google.model_mapping = json.dumps( - {"test-gemini": "models/gemini-2.0-flash"} - ) + self.provider_key_google.model_mapping = {"test-gemini": "models/gemini-2.0-flash"} self.provider_key_xai = MagicMock(spec=ProviderKey) self.provider_key_xai.provider_name = "xai" self.provider_key_xai.encrypted_api_key = "encrypted_xai_key" self.provider_key_xai.base_url = None - self.provider_key_xai.model_mapping = json.dumps({"test-xai": "grok-2-1212"}) + self.provider_key_xai.model_mapping = {"test-xai": "grok-2-1212"} self.provider_key_fireworks = MagicMock(spec=ProviderKey) self.provider_key_fireworks.provider_name = "fireworks" self.provider_key_fireworks.encrypted_api_key = 
"encrypted_fireworks_key" self.provider_key_fireworks.base_url = None - self.provider_key_fireworks.model_mapping = json.dumps( - {"test-fireworks": "accounts/fireworks/models/code-llama-7b"} - ) + self.provider_key_fireworks.model_mapping = {"test-fireworks": "accounts/fireworks/models/code-llama-7b"} self.provider_key_openrouter = MagicMock(spec=ProviderKey) self.provider_key_openrouter.provider_name = "openrouter" self.provider_key_openrouter.encrypted_api_key = "encrypted_openrouter_key" self.provider_key_openrouter.base_url = None - self.provider_key_openrouter.model_mapping = json.dumps( - {"test-openrouter": "gpt-4o"} - ) + self.provider_key_openrouter.model_mapping = {"test-openrouter": "gpt-4o"} self.provider_key_together = MagicMock(spec=ProviderKey) self.provider_key_together.provider_name = "together" self.provider_key_together.encrypted_api_key = "encrypted_together_key" self.provider_key_together.base_url = None - self.provider_key_together.model_mapping = json.dumps( - {"test-together": "UAE-Large-V1"} - ) + self.provider_key_together.model_mapping = {"test-together": "UAE-Large-V1"} self.provider_key_azure = MagicMock(spec=ProviderKey) self.provider_key_azure.provider_name = "azure" self.provider_key_azure.encrypted_api_key = "encrypted_azure_key" self.provider_key_azure.base_url = "https://test-azure.openai.com" - self.provider_key_azure.model_mapping = json.dumps({"test-azure": "gpt-4o"}) + self.provider_key_azure.model_mapping = {"test-azure": "gpt-4o"} self.provider_key_bedrock = MagicMock(spec=ProviderKey) self.provider_key_bedrock.provider_name = "bedrock" self.provider_key_bedrock.encrypted_api_key = "encrypted_bedrock_key" self.provider_key_bedrock.base_url = None - self.provider_key_bedrock.model_mapping = json.dumps({"test-bedrock": "claude-3-5-sonnet-20240620-v1:0"}) + self.provider_key_bedrock.model_mapping = {"test-bedrock": "claude-3-5-sonnet-20240620-v1:0"} self.user.provider_keys = [ self.provider_key_openai, @@ -108,12 +100,12 @@ 
def setUp(self): self.provider_key_bedrock, ] - # Mock DB - self.db = MagicMock() + # Mock AsyncSession DB + self.db = AsyncMock() # Clear caches - provider_service_cache.clear() - user_cache.clear() + await async_provider_service_cache.clear() + await async_user_cache.clear() # Create the service with patched decrypt_api_key to avoid actual decryption with patch("app.services.provider_service.decrypt_api_key") as mock_decrypt: @@ -141,8 +133,11 @@ def setUp(self): # Mock user.id for the new constructor signature self.user.id = 1 - # Mock the database query for provider keys - self.db.query.return_value.filter.return_value.all.return_value = [ + # Mock the async database execute() pattern for provider keys + # Create mock result object + mock_result = MagicMock() # Result object should be sync, not AsyncMock + mock_scalars = MagicMock() # Don't use AsyncMock for scalars object + mock_scalars.all.return_value = [ self.provider_key_openai, self.provider_key_anthropic, self.provider_key_google, @@ -153,11 +148,13 @@ def setUp(self): self.provider_key_azure, self.provider_key_bedrock, ] + mock_result.scalars.return_value = mock_scalars # scalars() returns sync object + self.db.execute = AsyncMock(return_value=mock_result) # Only execute() is async self.service = ProviderService(self.user.id, self.db) # Pre-load the keys for testing - self.service._load_provider_keys() + await self.service._load_provider_keys_async() async def test_load_provider_keys(self): """Test loading provider keys""" diff --git a/tests/unit_tests/test_provider_service_images.py b/tests/unit_tests/test_provider_service_images.py index 1451585..4792d1e 100644 --- a/tests/unit_tests/test_provider_service_images.py +++ b/tests/unit_tests/test_provider_service_images.py @@ -23,40 +23,38 @@ from app.models.user import User from app.services.provider_service import ProviderService from app.services.providers.adapter_factory import ProviderAdapterFactory -from app.core.cache import provider_service_cache, 
user_cache +from app.core.async_cache import async_provider_service_cache, async_user_cache class TestProviderServiceImages(TestCase): """Test cases for ProviderService images endpoints""" - def setUp(self): + async def asyncSetUp(self): # Mock user with provider keys self.user = MagicMock(spec=User) self.provider_key_openai = MagicMock(spec=ProviderKey) self.provider_key_openai.provider_name = "openai" self.provider_key_openai.encrypted_api_key = "encrypted_openai_key" self.provider_key_openai.base_url = None - self.provider_key_openai.model_mapping = json.dumps({"dall-e-2": "dall-e-2"}) + self.provider_key_openai.model_mapping = {"dall-e-2": "dall-e-2"} self.provider_key_anthropic = MagicMock(spec=ProviderKey) self.provider_key_anthropic.provider_name = "anthropic" self.provider_key_anthropic.encrypted_api_key = "encrypted_anthropic_key" self.provider_key_anthropic.base_url = None - self.provider_key_anthropic.model_mapping = json.dumps( - {"custom-anthropic": "claude-3-opus", "claude-3-opus": "claude-3-opus"} - ) + self.provider_key_anthropic.model_mapping = {"custom-anthropic": "claude-3-opus", "claude-3-opus": "claude-3-opus"} self.user.provider_keys = [ self.provider_key_openai, self.provider_key_anthropic, ] - # Mock DB - self.db = MagicMock() + # Mock AsyncSession DB + self.db = AsyncMock() # Clear caches - provider_service_cache.clear() - user_cache.clear() + await async_provider_service_cache.clear() + await async_user_cache.clear() # Remove ProviderService creation from setUp # It will be created in each test after patching @@ -85,15 +83,19 @@ async def test_process_request_images_generations_routing( # Create the service with the NEW constructor signature (user.id) self.user.id = 1 - # Mock the database query that the new loading mechanism uses - self.db.query.return_value.filter.return_value.all.return_value = [ + # Mock the async database execute() pattern for provider keys + mock_result = MagicMock() # Result object should be sync, not AsyncMock + 
mock_scalars = MagicMock() # Don't use AsyncMock for scalars object + mock_scalars.all.return_value = [ self.provider_key_openai, self.provider_key_anthropic, ] + mock_result.scalars.return_value = mock_scalars # scalars() returns sync object + self.db.execute = AsyncMock(return_value=mock_result) # Only execute() is async service = ProviderService(self.user.id, self.db) # Let the service load keys properly through the new mechanism - service._load_provider_keys() + await service._load_provider_keys_async() # mock openai image generation response # no need to mock the response for anthropic @@ -151,15 +153,19 @@ async def test_process_request_images_edits_routing( # Create the service with the NEW constructor signature (user.id) self.user.id = 1 - # Mock the database query that the new loading mechanism uses - self.db.query.return_value.filter.return_value.all.return_value = [ + # Mock the async database execute() pattern for provider keys + mock_result = MagicMock() # Result object should be sync, not AsyncMock + mock_scalars = MagicMock() # Don't use AsyncMock for scalars object + mock_scalars.all.return_value = [ self.provider_key_openai, self.provider_key_anthropic, ] + mock_result.scalars.return_value = mock_scalars # scalars() returns sync object + self.db.execute = AsyncMock(return_value=mock_result) # Only execute() is async service = ProviderService(self.user.id, self.db) # Let the service load keys properly through the new mechanism - service._load_provider_keys() + await service._load_provider_keys_async() # mock openai image edits response # no need to mock the response for anthropic diff --git a/tools/diagnostics/fix_model_mapping.py b/tools/diagnostics/fix_model_mapping.py index d883a30..97e5a76 100755 --- a/tools/diagnostics/fix_model_mapping.py +++ b/tools/diagnostics/fix_model_mapping.py @@ -4,11 +4,12 @@ Specifically for fixing the gpt-4o to mock-gpt-4o mapping issue. 
""" +import asyncio import os import sys from pathlib import Path -from app.core.database import get_db +from app.core.database import get_async_db # Add the project root to the Python path script_dir = Path(__file__).resolve().parent.parent.parent @@ -18,13 +19,14 @@ os.chdir(script_dir) -def fix_model_mappings(): +async def fix_model_mappings(): """Fix model mappings by clearing caches""" print("\n🔧 FIXING MODEL MAPPINGS") print("======================") # Get DB session - next(get_db()) + async with get_async_db() as db: + pass # Clear all caches to ensure changes take effect print("🔄 Invalidating provider service cache for all users") @@ -38,9 +40,9 @@ def fix_model_mappings(): return True -def main(): +async def main(): """Main entry point""" - if fix_model_mappings(): + if await fix_model_mappings(): print( "\n✅ Model mappings have been fixed. Use check_model_mappings.py to verify." ) @@ -51,4 +53,4 @@ def main(): if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/uv.lock b/uv.lock index a45c8d3..32eea87 100644 --- a/uv.lock +++ b/uv.lock @@ -532,7 +532,7 @@ dependencies = [ { name = "python-multipart" }, { name = "redis" }, { name = "requests" }, - { name = "sqlalchemy" }, + { name = "sqlalchemy", extra = ["asyncio"] }, { name = "svix" }, { name = "uvicorn" }, ] @@ -580,7 +580,7 @@ requires-dist = [ { name = "redis", specifier = ">=4.6.0" }, { name = "requests", specifier = ">=2.28.0" }, { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.2.0" }, - { name = "sqlalchemy", specifier = ">=2.0.0" }, + { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.0" }, { name = "svix", specifier = ">=1.13.0" }, { name = "uvicorn", specifier = ">=0.22.0" }, ] @@ -1394,6 +1394,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/7c/5fc8e802e7506fe8b55a03a2e1dab156eae205c91bee46305755e086d2e2/sqlalchemy-2.0.40-py3-none-any.whl", hash = "sha256:32587e2e1e359276957e6fe5dad089758bc042a971a8a09ae8ecf7a8fe23d07a", size = 
1903894 }, ] +[package.optional-dependencies] +asyncio = [ + { name = "greenlet" }, +] + [[package]] name = "starlette" version = "0.46.1"