diff --git a/alembic/versions/4a685a55c5cd_create_usage_tracker_table.py b/alembic/versions/4a685a55c5cd_create_usage_tracker_table.py
new file mode 100644
index 0000000..016af40
--- /dev/null
+++ b/alembic/versions/4a685a55c5cd_create_usage_tracker_table.py
@@ -0,0 +1,44 @@
+"""create usage tracker table
+
+Revision ID: 4a685a55c5cd
+Revises: 9daf34d338f7
+Create Date: 2025-08-02 12:29:07.955645
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import UUID
+import uuid
+
+
+# revision identifiers, used by Alembic.
+revision = '4a685a55c5cd'
+down_revision = '9daf34d338f7'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "usage_tracker",
+        sa.Column("id", UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid.uuid4),
+        sa.Column("user_id", sa.Integer(), nullable=False),
+        sa.Column("provider_key_id", sa.Integer(), nullable=False),
+        sa.Column("forge_key_id", sa.Integer(), nullable=False),
+        sa.Column("model", sa.String(), nullable=True),
+        sa.Column("endpoint", sa.String(), nullable=True),
+        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
+        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
+        sa.Column("input_tokens", sa.Integer(), nullable=True),
+        sa.Column("output_tokens", sa.Integer(), nullable=True),
+        sa.Column("cached_tokens", sa.Integer(), nullable=True),
+        sa.Column("reasoning_tokens", sa.Integer(), nullable=True),
+
+        sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
+        sa.ForeignKeyConstraint(["provider_key_id"], ["provider_keys.id"], ondelete="CASCADE"),
+        sa.ForeignKeyConstraint(["forge_key_id"], ["forge_api_keys.id"], ondelete="CASCADE"),
+    )
+
+
+def downgrade() -> None:
+    op.drop_table("usage_tracker")
diff --git a/alembic/versions/831fc2cf16ee_enable_soft_deletion_for_provider_keys_.py b/alembic/versions/831fc2cf16ee_enable_soft_deletion_for_provider_keys_.py
new file mode 100644
index 0000000..0817aaf
--- /dev/null
+++ b/alembic/versions/831fc2cf16ee_enable_soft_deletion_for_provider_keys_.py
@@ -0,0 +1,28 @@
+"""enable soft deletion for provider keys and api keys
+
+Revision ID: 831fc2cf16ee
+Revises: 4a685a55c5cd
+Create Date: 2025-08-02 17:50:12.224293
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '831fc2cf16ee' +down_revision = '4a685a55c5cd' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('provider_keys', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) + op.add_column('forge_api_keys', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True)) + op.alter_column('forge_api_keys', 'key', nullable=True) + + +def downgrade() -> None: + op.drop_column('provider_keys', 'deleted_at') + op.drop_column('forge_api_keys', 'deleted_at') + op.alter_column('forge_api_keys', 'key', nullable=False) diff --git a/app/api/dependencies.py b/app/api/dependencies.py index 4cbf455..5a639e5 100644 --- a/app/api/dependencies.py +++ b/app/api/dependencies.py @@ -1,6 +1,7 @@ import asyncio import contextlib import json +from typing import Any # Add environment variables for Clerk import os @@ -147,11 +148,12 @@ async def get_api_key_from_headers(request: Request) -> str: status_code=status.HTTP_401_UNAUTHORIZED, detail="API key not found in headers", ) - + async def get_user_by_api_key( request: Request = None, db: AsyncSession = Depends(get_async_db), + include_api_key_id: bool = False, ) -> User: """Get user by API key from headers, with caching""" api_key_from_header = await get_api_key_from_headers(request) @@ -185,7 +187,22 @@ async def get_user_by_api_key( # Return a transient User object from cached data, not a managed one. # This avoids the db.merge() call and its expensive SELECT query. # Downstream code can access attributes, but not lazy-load relationships. - return User(**cached_user.model_dump()) + if not include_api_key_id: + return User(**cached_user.model_dump()) + else: + result = await db.execute( + select(ForgeApiKey) + .options(selectinload(ForgeApiKey.allowed_provider_keys)) + .filter(ForgeApiKey.key == api_key_from_header, ForgeApiKey.is_active, ForgeApiKey.deleted_at == None) + ) + api_key_record = result.scalar_one_or_none() + + if not api_key_record: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + return User(**cached_user.model_dump()), api_key_record.id # Try scope cache first – this doesn't remove the need to verify the key, but it # avoids an extra query later in /models. @@ -194,7 +211,7 @@ async def get_user_by_api_key( result = await db.execute( select(ForgeApiKey) .options(selectinload(ForgeApiKey.allowed_provider_keys)) - .filter(ForgeApiKey.key == api_key_from_header, ForgeApiKey.is_active) + .filter(ForgeApiKey.key == api_key_from_header, ForgeApiKey.is_active, ForgeApiKey.deleted_at == None) ) api_key_record = result.scalar_one_or_none() @@ -243,9 +260,24 @@ async def get_user_by_api_key( # Cache the user data for future requests await cache_user_async(api_key, user) + if include_api_key_id: + return user, api_key_record.id return user +async def get_user_details_by_api_key( + request: Request = None, + db: AsyncSession = Depends(get_async_db), +) -> dict[str, Any]: + """Get user details by API key from headers, with caching""" + user, api_key_id = await get_user_by_api_key(request, db, include_api_key_id=True) + + return { + "user": user, + "api_key_id": api_key_id, + } + + async def validate_clerk_jwt(token: str = Depends(clerk_token_header)): """ Validate a Clerk JWT token using JWKS from Clerk. 
diff --git a/app/api/routes/api_keys.py b/app/api/routes/api_keys.py index 27f1642..31204e8 100644 --- a/app/api/routes/api_keys.py +++ b/app/api/routes/api_keys.py @@ -1,7 +1,8 @@ from typing import Any +from datetime import UTC, datetime from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -18,7 +19,7 @@ from app.core.async_cache import invalidate_forge_scope_cache_async, invalidate_user_cache_async, invalidate_provider_service_cache_async from app.core.database import get_async_db from app.core.security import generate_forge_api_key -from app.models.forge_api_key import ForgeApiKey +from app.models.forge_api_key import ForgeApiKey, forge_api_key_provider_scope_association from app.models.provider_key import ProviderKey as ProviderKeyModel from app.models.user import User as UserModel @@ -36,7 +37,7 @@ async def _get_api_keys_internal( result = await db.execute( select(ForgeApiKey) .options(selectinload(ForgeApiKey.allowed_provider_keys)) - .filter(ForgeApiKey.user_id == current_user.id) + .filter(ForgeApiKey.user_id == current_user.id, ForgeApiKey.deleted_at == None) ) api_keys = result.scalars().all() @@ -71,6 +72,7 @@ async def _create_api_key_internal( select(ProviderKeyModel).filter( ProviderKeyModel.id.in_(api_key_create.allowed_provider_key_ids), ProviderKeyModel.user_id == current_user.id, + ProviderKeyModel.deleted_at == None, ) ) allowed_providers = result.scalars().all() @@ -103,7 +105,7 @@ async def _update_api_key_internal( result = await db.execute( select(ForgeApiKey) .options(selectinload(ForgeApiKey.allowed_provider_keys)) - .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id) + .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id, ForgeApiKey.deleted_at == None) ) db_api_key = result.scalar_one_or_none() @@ -126,6 +128,7 @@ async def _update_api_key_internal( select(ProviderKeyModel).filter( ProviderKeyModel.id.in_(api_key_update.allowed_provider_key_ids), ProviderKeyModel.user_id == current_user.id, + ProviderKeyModel.deleted_at == None, ) ) allowed_providers = result.scalars().all() @@ -161,7 +164,7 @@ async def _delete_api_key_internal( result = await db.execute( select(ForgeApiKey) .options(selectinload(ForgeApiKey.allowed_provider_keys)) - .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id) + .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id, ForgeApiKey.deleted_at == None) ) db_api_key = result.scalar_one_or_none() @@ -180,7 +183,16 @@ async def _delete_api_key_internal( "allowed_provider_key_ids": [pk.id for pk in db_api_key.allowed_provider_keys], } - await db.delete(db_api_key) + # do soft deletion here. 
Set the deleted_at column to the current time + db_api_key.deleted_at = datetime.now(UTC) + + # Delete the record from forge_api_key_provider_scope_association where forge_api_key_id matches current id + await db.execute( + delete(forge_api_key_provider_scope_association).where( + forge_api_key_provider_scope_association.c.forge_api_key_id == db_api_key.id + ) + ) + await db.commit() await invalidate_user_cache_async(key_to_invalidate) @@ -198,7 +210,7 @@ async def _regenerate_api_key_internal( result = await db.execute( select(ForgeApiKey) .options(selectinload(ForgeApiKey.allowed_provider_keys)) - .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id) + .filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id, ForgeApiKey.deleted_at == None) ) db_api_key = result.scalar_one_or_none() diff --git a/app/api/routes/provider_keys.py b/app/api/routes/provider_keys.py index d863bcf..d874713 100644 --- a/app/api/routes/provider_keys.py +++ b/app/api/routes/provider_keys.py @@ -1,10 +1,9 @@ -import json +from datetime import UTC, datetime from typing import Any from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession -from starlette import status from app.api.dependencies import ( get_current_active_user, @@ -16,10 +15,11 @@ ProviderKeyUpdate, ProviderKeyUpsertItem, ) -from app.core.async_cache import invalidate_provider_service_cache_async +from app.core.async_cache import invalidate_provider_service_cache_async, invalidate_forge_scope_cache_async from app.core.database import get_async_db from app.core.logger import get_logger from app.core.security import decrypt_api_key, encrypt_api_key +from app.models.forge_api_key import forge_api_key_provider_scope_association from app.models.provider_key import ProviderKey as ProviderKeyModel from app.models.user import User as UserModel from app.services.providers.adapter_factory import ProviderAdapterFactory @@ -56,7 +56,7 @@ async def _get_provider_keys_internal( Internal logic to get all provider keys for the current user. 
""" result = await db.execute( - select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id) + select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id, ProviderKeyModel.deleted_at == None) ) provider_keys = result.scalars().all() return [ProviderKey.model_validate(pk) for pk in provider_keys] @@ -94,6 +94,7 @@ async def _create_provider_key_internal( select(ProviderKeyModel).filter( ProviderKeyModel.user_id == current_user.id, ProviderKeyModel.provider_name == provider_key_create.provider_name, + ProviderKeyModel.deleted_at == None, ) ) existing_key = result.scalar_one_or_none() @@ -148,6 +149,7 @@ async def _update_provider_key_internal( select(ProviderKeyModel).filter( ProviderKeyModel.provider_name == provider_name, ProviderKeyModel.user_id == current_user.id, + ProviderKeyModel.deleted_at == None, ) ) db_provider_key = result.scalar_one_or_none() @@ -175,6 +177,7 @@ async def _process_provider_key_delete_data( select(ProviderKeyModel).filter( ProviderKeyModel.provider_name == provider_name, ProviderKeyModel.user_id == user_id, + ProviderKeyModel.deleted_at == None, ) ) db_provider_key = result.scalar_one_or_none() @@ -185,9 +188,18 @@ async def _process_provider_key_delete_data( # Store the provider key data before deletion provider_key_data = ProviderKey.model_validate(db_provider_key) - await db.delete(db_provider_key) + # do soft deletion here. Set the deleted_at column to the current time + db_provider_key.deleted_at = datetime.now(UTC) + + # Delete the record from forge_api_key_provider_scope_association where provider_key_id matches current id + await db.execute( + delete(forge_api_key_provider_scope_association).where( + forge_api_key_provider_scope_association.c.provider_key_id == db_provider_key.id, + ) + ) + scoped_forge_api_keys = db_provider_key.scoped_forge_api_keys - return provider_key_data + return provider_key_data, [scoped_forge_api_key.key for scoped_forge_api_key in scoped_forge_api_keys] async def _delete_provider_key_internal( @@ -196,11 +208,13 @@ async def _delete_provider_key_internal( """ Internal logic to delete a provider key for the current user. """ - provider_key_data = await _process_provider_key_delete_data(db, provider_name, current_user.id) + provider_key_data, scoped_forge_api_keys = await _process_provider_key_delete_data(db, provider_name, current_user.id) await db.commit() # Invalidate caches after deleting a provider key await invalidate_provider_service_cache_async(current_user.id) + for scoped_forge_api_key in scoped_forge_api_keys: + await invalidate_forge_scope_cache_async(scoped_forge_api_key) return provider_key_data @@ -302,13 +316,14 @@ async def _batch_upsert_provider_keys_internal( # 1. Fetch all existing keys for the user result = await db.execute( - select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id) + select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id, ProviderKeyModel.deleted_at == None) ) existing_keys_query = result.scalars().all() # 2. 
Map them by provider_name for efficient lookup existing_keys_map: dict[str, ProviderKeyModel] = { key.provider_name: key for key in existing_keys_query } + invalidated_forge_api_keys = set() for item in items: if "****" in item.api_key: @@ -320,7 +335,8 @@ async def _batch_upsert_provider_keys_internal( # Handle deletion if api_key is "DELETE" if item.api_key == "DELETE": if existing_provider_key: - await _process_provider_key_delete_data(db, item.provider_name, current_user.id) + _, scoped_forge_api_keys = await _process_provider_key_delete_data(db, item.provider_name, current_user.id) + invalidated_forge_api_keys.update(scoped_forge_api_keys) processed = True elif existing_provider_key: # Update existing key db_key_to_process = await _process_provider_key_update_data(existing_provider_key, ProviderKeyUpdate.model_validate(item.model_dump(exclude_unset=True))) @@ -357,6 +373,8 @@ async def _batch_upsert_provider_keys_internal( await db.refresh(key) # Refresh each key to get DB-generated values like id, timestamps processed_keys = [ProviderKey.model_validate(key) for key in processed_keys] await invalidate_provider_service_cache_async(current_user.id) + for key in invalidated_forge_api_keys: + await invalidate_forge_scope_cache_async(key) except Exception as e: await db.rollback() error_message_prefix = "Error during final commit/refresh in batch upsert" diff --git a/app/api/routes/proxy.py b/app/api/routes/proxy.py index 82384b6..782ee88 100644 --- a/app/api/routes/proxy.py +++ b/app/api/routes/proxy.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import selectinload from starlette.responses import StreamingResponse -from app.api.dependencies import get_user_by_api_key +from app.api.dependencies import get_user_by_api_key, get_user_details_by_api_key from app.api.schemas.openai import ( ChatCompletionRequest, CompletionRequest, @@ -16,7 +16,7 @@ ImageEditsRequest, ImageGenerationRequest, ) -from app.core.async_cache import async_provider_service_cache +from app.core.async_cache import forge_scope_cache_async, get_forge_scope_cache_async from app.core.database import get_async_db from app.core.logger import get_logger from app.models.forge_api_key import ForgeApiKey @@ -46,13 +46,13 @@ async def _get_allowed_provider_names( if allowed is not None: return allowed - allowed = await async_provider_service_cache.get(f"forge_scope:{api_key}") + allowed = await get_forge_scope_cache_async(api_key) if allowed is None: result = await db.execute( select(ForgeApiKey) .options(selectinload(ForgeApiKey.allowed_provider_keys)) - .filter(ForgeApiKey.key == f"forge-{api_key}", ForgeApiKey.is_active) + .filter(ForgeApiKey.key == f"forge-{api_key}", ForgeApiKey.is_active, ForgeApiKey.deleted_at == None) ) forge_key = result.scalar_one_or_none() if forge_key is None: @@ -60,9 +60,7 @@ async def _get_allowed_provider_names( status_code=401, detail="Forge API key not found or inactive" ) allowed = [pk.provider_name for pk in forge_key.allowed_provider_keys] - await async_provider_service_cache.set( - f"forge_scope:{api_key}", allowed, ttl=300 - ) + await forge_scope_cache_async(api_key, allowed) request.state.allowed_provider_names = allowed return allowed @@ -72,7 +70,7 @@ async def _get_allowed_provider_names( async def create_chat_completion( request: Request, chat_request: ChatCompletionRequest, - user: User = Depends(get_user_by_api_key), + user_details: dict[str, Any] = Depends(get_user_details_by_api_key), db: AsyncSession = Depends(get_async_db), ) -> Any: """ @@ -80,7 +78,9 @@ async def 
create_chat_completion( """ try: # Get cached provider service instance - provider_service = await ProviderService.async_get_instance(user, db) + user = user_details["user"] + api_key_id = user_details["api_key_id"] + provider_service = await ProviderService.async_get_instance(user, db, api_key_id=api_key_id) # Convert to dict and extract request properties payload = chat_request.dict(exclude_unset=True) @@ -110,8 +110,10 @@ async def create_chat_completion( # Otherwise, return the JSON response directly return response except ValueError as err: + logger.exception(f"Error processing chat completion request: {str(err)}") raise HTTPException(status_code=400, detail=str(err)) from err except Exception as err: + logger.exception(f"Error processing chat completion request: {str(err)}") raise HTTPException( status_code=500, detail=f"Error processing request: {str(err)}" ) from err @@ -121,14 +123,16 @@ async def create_chat_completion( async def create_completion( request: Request, completion_request: CompletionRequest, - user: User = Depends(get_user_by_api_key), + user_details: dict[str, Any] = Depends(get_user_details_by_api_key), db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create a completion (OpenAI-compatible endpoint). """ try: - provider_service = await ProviderService.async_get_instance(user, db) + user = user_details["user"] + api_key_id = user_details["api_key_id"] + provider_service = await ProviderService.async_get_instance(user, db, api_key_id=api_key_id) allowed_provider_names = await _get_allowed_provider_names(request, db) response = await provider_service.process_request( @@ -153,8 +157,10 @@ async def create_completion( # Otherwise, return the JSON response directly return response except ValueError as err: + logger.exception(f"Error processing completion request: {str(err)}") raise HTTPException(status_code=400, detail=str(err)) from err except Exception as err: + logger.exception(f"Error processing completion request: {str(err)}") raise HTTPException( status_code=500, detail=f"Error processing request: {str(err)}" ) from err @@ -164,14 +170,16 @@ async def create_completion( async def create_image_generation( request: Request, image_generation_request: ImageGenerationRequest, - user: User = Depends(get_user_by_api_key), + user_details: dict[str, Any] = Depends(get_user_details_by_api_key), db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create an image generation (OpenAI-compatible endpoint). 
""" try: - provider_service = await ProviderService.async_get_instance(user, db) + user = user_details["user"] + api_key_id = user_details["api_key_id"] + provider_service = await ProviderService.async_get_instance(user, db, api_key_id=api_key_id) payload = image_generation_request.model_dump(mode="json", exclude_unset=True) @@ -195,11 +203,13 @@ async def create_image_generation( async def create_image_edits( request: Request, image_edits_request: ImageEditsRequest, - user: User = Depends(get_user_by_api_key), + user_details: dict[str, Any] = Depends(get_user_details_by_api_key), db: AsyncSession = Depends(get_async_db), ) -> Any: try: - provider_service = await ProviderService.async_get_instance(user, db) + user = user_details["user"] + api_key_id = user_details["api_key_id"] + provider_service = await ProviderService.async_get_instance(user, db, api_key_id=api_key_id) payload = image_edits_request.model_dump(mode="json", exclude_unset=True) allowed_provider_names = await _get_allowed_provider_names(request, db) response = await provider_service.process_request( @@ -236,7 +246,7 @@ async def list_models( ) return {"object": "list", "data": models} except Exception as err: - logger.error(f"Error listing models: {str(err)}") + logger.exception(f"Error listing models: {str(err)}") raise HTTPException( status_code=500, detail=f"Error listing models: {str(err)}" ) from err @@ -247,14 +257,16 @@ async def list_models( async def create_embeddings( request: Request, embeddings_request: EmbeddingsRequest, - user: User = Depends(get_user_by_api_key), + user_details: dict[str, Any] = Depends(get_user_details_by_api_key), db: AsyncSession = Depends(get_async_db), ) -> Any: """ Create embeddings (OpenAI-compatible endpoint). """ try: - provider_service = await ProviderService.async_get_instance(user, db) + user = user_details["user"] + api_key_id = user_details["api_key_id"] + provider_service = await ProviderService.async_get_instance(user, db, api_key_id=api_key_id) payload = embeddings_request.model_dump(mode="json", exclude_unset=True) allowed_provider_names = await _get_allowed_provider_names(request, db) response = await provider_service.process_request( diff --git a/app/core/async_cache.py b/app/core/async_cache.py index 3775e83..408efb0 100644 --- a/app/core/async_cache.py +++ b/app/core/async_cache.py @@ -504,7 +504,7 @@ async def warm_cache_async(db: AsyncSession) -> None: # Get user's Forge API keys result = await db.execute( select(ForgeApiKey) - .filter(ForgeApiKey.user_id == user.id, ForgeApiKey.is_active) + .filter(ForgeApiKey.user_id == user.id, ForgeApiKey.is_active, ForgeApiKey.deleted_at == None) ) forge_api_keys = result.scalars().all() diff --git a/app/models/base.py b/app/models/base.py index e46b194..078b93c 100644 --- a/app/models/base.py +++ b/app/models/base.py @@ -1,3 +1,4 @@ +from datetime import UTC from datetime import datetime from sqlalchemy import Column, DateTime, Integer @@ -11,3 +12,4 @@ class BaseModel(Base): id = Column(Integer, primary_key=True, index=True) created_at = Column(DateTime, default=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + deleted_at = Column(DateTime(timezone=True), nullable=True) diff --git a/app/models/forge_api_key.py b/app/models/forge_api_key.py index eeca194..d6774ac 100644 --- a/app/models/forge_api_key.py +++ b/app/models/forge_api_key.py @@ -38,6 +38,7 @@ class ForgeApiKey(Base): is_active = Column(Boolean, default=True) created_at = Column(DateTime, 
default=datetime.datetime.utcnow)
     last_used_at = Column(DateTime, nullable=True)
+    deleted_at = Column(DateTime(timezone=True), nullable=True)
 
     # Relationship to user
     user = relationship("User", back_populates="api_keys")
diff --git a/app/models/usage_tracker.py b/app/models/usage_tracker.py
new file mode 100644
index 0000000..161fec0
--- /dev/null
+++ b/app/models/usage_tracker.py
@@ -0,0 +1,23 @@
+import datetime
+from datetime import UTC
+import uuid
+
+from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
+from sqlalchemy.dialects.postgresql import UUID
+from .base import Base
+
+class UsageTracker(Base):
+    __tablename__ = "usage_tracker"
+
+    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    user_id = Column(Integer, ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
+    provider_key_id = Column(Integer, ForeignKey("provider_keys.id", ondelete="CASCADE"), nullable=False)
+    forge_key_id = Column(Integer, ForeignKey("forge_api_keys.id", ondelete="CASCADE"), nullable=False)
+    model = Column(String, nullable=True)
+    endpoint = Column(String, nullable=True)
+    created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.datetime.now(UTC))
+    updated_at = Column(DateTime(timezone=True), nullable=True)
+    input_tokens = Column(Integer, nullable=True)
+    output_tokens = Column(Integer, nullable=True)
+    cached_tokens = Column(Integer, nullable=True)
+    reasoning_tokens = Column(Integer, nullable=True)
diff --git a/app/services/provider_service.py b/app/services/provider_service.py
index 281050f..ca0c256 100644
--- a/app/services/provider_service.py
+++ b/app/services/provider_service.py
@@ -1,4 +1,5 @@
 import asyncio
+import uuid
 import inspect
 import json
 import os
@@ -14,16 +15,32 @@
 from app.core.security import decrypt_api_key, encrypt_api_key
 from app.exceptions.exceptions import InvalidProviderException, BaseInvalidRequestException, InvalidForgeKeyException
 from app.models.user import User
-from app.services.usage_stats_service import UsageStatsService
+from app.core.database import get_db_session
 
 from .providers.adapter_factory import ProviderAdapterFactory
 from .providers.base import ProviderAdapter
+from .providers.usage_tracker_service import UsageTrackerService
 
 logger = get_logger(name="provider_service")
 
 # Add constants at the top of the file, after imports
 MODEL_PARTS_MIN_LENGTH = 2  # Minimum number of parts in a model name (e.g., "gpt-4")
 
+# Run the usage-tracker update in a background task that won't be cancelled,
+# even if the streaming response is cancelled by a client disconnect.
+async def update_usage_in_background(usage_tracker_id: uuid.UUID, input_tokens: int, output_tokens: int, cached_tokens: int, reasoning_tokens: int):
+    # Use a fresh DB session for logging, since the original request session
+    # may have been closed by FastAPI after the response was returned.
+
+    async with get_db_session() as new_db_session:
+        await UsageTrackerService.update_usage_tracker(
+            db=new_db_session,
+            usage_tracker_id=usage_tracker_id,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+            cached_tokens=cached_tokens,
+            reasoning_tokens=reasoning_tokens,
+        )
 
 class ProviderService:
     """Service for handling provider API calls.
@@ -59,18 +76,18 @@ def _model_cache_key(cls, provider_name: str, cache_key: str) -> str: # Using a stable namespace makes invalidation easier return f"models:{provider_name}:{cache_key}" - def __init__(self, user_id: int, db: AsyncSession): + def __init__(self, user_id: int, db: AsyncSession, api_key_id: int | None = None): self.user_id = user_id self.db = db + self.api_key_id = api_key_id self.provider_keys: dict[str, dict[str, Any]] = {} self.adapters = self._get_adapters() # Load provider keys on demand, not during initialization self._keys_loaded = False @classmethod - async def async_get_instance(cls, user: User, db: AsyncSession) -> "ProviderService": + async def async_get_instance(cls, user: User, db: AsyncSession, api_key_id: int | None = None) -> "ProviderService": """Get a cached instance of ProviderService for a user or create a new one (async version)""" - from app.core.async_cache import async_provider_service_cache cache_key = f"provider_service:{user.id}" cached_instance = await async_provider_service_cache.get(cache_key) @@ -81,6 +98,7 @@ async def async_get_instance(cls, user: User, db: AsyncSession) -> "ProviderServ ) # Update the db session reference for the cached instance cached_instance.db = db + cached_instance.api_key_id = api_key_id return cached_instance # No cached instance found, create a new one (async) @@ -88,7 +106,7 @@ async def async_get_instance(cls, user: User, db: AsyncSession) -> "ProviderServ logger.debug( f"Creating new ProviderService instance for user {user.id} (async)" ) - instance = cls(user.id, db) + instance = cls(user.id, db, api_key_id) await async_provider_service_cache.set(cache_key, instance) return instance @@ -164,7 +182,7 @@ async def _load_provider_keys(self) -> dict[str, dict[str, Any]]: # Query ProviderKey directly by user_id from app.models.provider_key import ProviderKey - result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id)) + result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None)) provider_key_records = result.scalars().all() keys = {} @@ -172,6 +190,7 @@ async def _load_provider_keys(self) -> dict[str, dict[str, Any]]: model_mapping = provider_key.model_mapping or {} keys[provider_key.provider_name] = { + "id": provider_key.id, "api_key": decrypt_api_key(provider_key.encrypted_api_key), "base_url": provider_key.base_url, "model_mapping": model_mapping, @@ -195,8 +214,6 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]: return self.provider_keys # Try to get provider keys from cache - from app.core.async_cache import async_provider_service_cache - cache_key = f"provider_keys:{self.user_id}" cached_keys = await async_provider_service_cache.get(cache_key) if cached_keys is not None: @@ -216,7 +233,7 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]: # Query ProviderKey directly by user_id from app.models.provider_key import ProviderKey - result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id)) + result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None)) provider_key_records = result.scalars().all() keys = {} @@ -224,6 +241,7 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]: model_mapping = provider_key.model_mapping or {} keys[provider_key.provider_name] = { + "id": provider_key.id, "api_key": decrypt_api_key(provider_key.encrypted_api_key), 
"base_url": provider_key.base_url, "model_mapping": model_mapping, @@ -286,11 +304,13 @@ def _get_provider_info_with_prefix( provider_data = self.provider_keys[matching_provider] model_mapping = provider_data.get("model_mapping", {}) + provider_key_id = provider_data.get("id") mapped_model = model_mapping.get(model_name, model_name) return ( matching_provider, mapped_model, provider_data.get("base_url"), + provider_key_id, ) def _find_provider_for_unprefixed_model( @@ -309,12 +329,14 @@ def _find_provider_for_unprefixed_model( # Check custom model mappings for provider_name, provider_data in sorted_providers: model_mapping = provider_data.get("model_mapping", {}) + provider_key_id = provider_data.get("id") if model in model_mapping: mapped_model = model_mapping[model] return ( provider_name, mapped_model, provider_data.get("base_url"), + provider_key_id, ) logger.error(f"No matching provider found for {model}") @@ -452,7 +474,7 @@ async def process_request( error=ValueError(error_message) ) - provider_name, actual_model, base_url = self._get_provider_info(model) + provider_name, actual_model, base_url, provider_key_id = self._get_provider_info(model) # Enforce scope restriction (case-insensitive). if allowed_provider_names is not None and ( @@ -485,6 +507,22 @@ async def process_request( adapter = ProviderAdapterFactory.get_adapter(provider_name, base_url, config) # Process the request through the adapter + usage_tracker_id = None + if self.api_key_id is not None and provider_key_id is not None: + usage_tracker_id = await UsageTrackerService.start_tracking_usage( + db=self.db, + user_id=self.user_id, + provider_key_id=provider_key_id, + forge_key_id=self.api_key_id, + model=actual_model, + endpoint=endpoint, + ) + else: + # TODO: this shouldn't happen, but we handle it gracefully as we don't want to break the flow + # Dive deeper into this if it ever happens + logger.info(f"api_key_id: {self.api_key_id}, provider_key_id: {provider_key_id}") + logger.warning("No API key ID or provider key ID found, skipping usage tracking") + if "completion" in endpoint: result = await adapter.process_completion( endpoint, @@ -522,6 +560,11 @@ async def process_request( else: error_message = f"Unsupported endpoint: {endpoint}" logger.error(error_message) + # Delete the usage tracker record if it exists + await UsageTrackerService.delete_usage_tracker_record( + db=self.db, + usage_tracker_id=usage_tracker_id, + ) raise NotImplementedError(error_message) # Track usage statistics if it's not a streaming response @@ -529,34 +572,28 @@ async def process_request( # Extract usage data from the response input_tokens = 0 output_tokens = 0 + cached_tokens = 0 + reasoning_tokens = 0 + # https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage if isinstance(result, dict) and "usage" in result: usage = result.get("usage", {}) input_tokens = usage.get("prompt_tokens", 0) output_tokens = usage.get("completion_tokens", 0) + cached_tokens = usage.get("prompt_tokens_details", {}).get("cached_tokens", 0) + reasoning_tokens = usage.get("completion_tokens_details", {}).get("reasoning_tokens", 0) - # Record the usage statistics using the new logging method - # Use a fresh DB session for logging, since the original request session - # may have been closed by FastAPI after the response was returned. 
- from app.core.database import get_db_session - - async with get_db_session() as new_db_session: - await UsageStatsService.log_api_request( - db=new_db_session, - user_id=self.user_id, - provider_name=provider_name, - model=actual_model, - endpoint=endpoint, - input_tokens=input_tokens, - output_tokens=output_tokens, - ) - return result + asyncio.create_task(update_usage_in_background(usage_tracker_id, input_tokens, output_tokens, cached_tokens, reasoning_tokens)) + return result else: # For streaming responses, wrap the generator to count tokens async def token_counting_stream() -> AsyncGenerator[bytes, None]: + approximate_input_tokens = 0 + approximate_output_tokens = 0 output_tokens = 0 input_tokens = 0 - has_final_usage = False + cached_tokens = 0 + reasoning_tokens = 0 chunks_processed = 0 # Get the streaming mode from the payload @@ -568,12 +605,11 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]: messages = payload.get("messages", []) # Rough estimate of input tokens based on message length - # This will be replaced with actual usage data if available in the final chunk for msg in messages: content = msg.get("content", "") if isinstance(content, str): # Rough approximation: 4 chars ~= 1 token - input_tokens += len(content) // 4 + approximate_input_tokens += len(content) // 4 try: async for chunk in result: @@ -595,18 +631,20 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]: data = json.loads(data_str) # Some providers include usage info in the last chunk + # https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-usage if "usage" in data and data["usage"]: logger.debug( f"Found usage data in chunk: {data['usage']}" ) usage = data.get("usage", {}) - input_tokens = usage.get( - "prompt_tokens", input_tokens - ) - output_tokens = usage.get( - "completion_tokens", output_tokens - ) - has_final_usage = True + input_tokens += usage.get( + "prompt_tokens", 0 + ) or 0 + output_tokens += usage.get( + "completion_tokens", 0 + ) or 0 + cached_tokens += usage.get("prompt_tokens_details", {}).get("cached_tokens", 0) or 0 + reasoning_tokens += usage.get("completion_tokens_details", {}).get("reasoning_tokens", 0) or 0 # Extract content from the chunk based on OpenAI format if "choices" in data: @@ -617,9 +655,9 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]: ): content = choice["delta"]["content"] # Only count tokens if we don't have final usage data - if not has_final_usage and content: + if content: # Count tokens in content (approx) - output_tokens += len(content) // 4 + approximate_output_tokens += len(content) // 4 except json.JSONDecodeError: # If JSON parsing fails, just continue pass @@ -640,25 +678,11 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]: logger.debug( f"Logging API request final details: provider={provider_name}, " f"model={actual_model}, input_tokens={input_tokens}, " - f"output_tokens={output_tokens}" + f"output_tokens={output_tokens}, cached_tokens={cached_tokens}, reasoning_tokens={reasoning_tokens}" ) - # Use a fresh DB session for logging, since the original request session - # may have been closed by FastAPI after the response was returned. 
- from app.core.database import get_db_session - - async with get_db_session() as new_db_session: - await UsageStatsService.log_api_request( - db=new_db_session, - user_id=self.user_id, - provider_name=provider_name, - model=actual_model, - endpoint=endpoint, - input_tokens=input_tokens, - output_tokens=output_tokens, - ) - - # End of token_counting_stream function + asyncio.create_task(update_usage_in_background(usage_tracker_id, input_tokens or approximate_input_tokens, output_tokens or approximate_output_tokens, cached_tokens, reasoning_tokens)) + return token_counting_stream() diff --git a/app/services/providers/anthropic_adapter.py b/app/services/providers/anthropic_adapter.py index 253fb35..f6f2580 100644 --- a/app/services/providers/anthropic_adapter.py +++ b/app/services/providers/anthropic_adapter.py @@ -3,7 +3,7 @@ import uuid from collections.abc import AsyncGenerator from http import HTTPStatus -from typing import Any, Callable +from typing import Any, Callable, Tuple import aiohttp @@ -37,7 +37,25 @@ def __init__( @property def provider_name(self) -> str: return self._provider_name - + + @staticmethod + def format_anthropic_usage(usage_data: dict[str, Any]) -> dict[str, Any]: + if not usage_data: + return None + + input_tokens = usage_data.get("input_tokens", 0) + output_tokens = usage_data.get("output_tokens", 0) + cached_tokens = usage_data.get("cache_creation_input_tokens", 0) or 0 + cached_tokens += usage_data.get("cache_read_input_tokens", 0) or 0 + return { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, + "prompt_tokens_details": { + "cached_tokens": cached_tokens, + }, + } + @staticmethod def convert_openai_image_content_to_anthropic( msg: dict[str, Any], @@ -59,6 +77,42 @@ def convert_openai_image_content_to_anthropic( } else: return {"type": "image", "source": {"type": "url", "url": data_url}} + + @staticmethod + def translate_anthropic_content_to_openai( + content: list[dict[str, Any]] | str | None, + ) -> Tuple[str, list[Any]]: + """Translate Anthropic content to OpenAI content""" + if content is None: + return "", [] + + if isinstance(content, str): + return content, [] + + text_content = "" + tool_calls = [] + for block in content: + if not block or not isinstance(block, dict): + continue + + if block.get("type") == "text": + text_content += block.get("text", "") + elif block.get("type") == "tool_use": + # Convert Anthropic tool use to OpenAI tool call format + tool_calls.append( + { + "id": block.get( + "id", f"call_{uuid.uuid4().hex[:8]}" + ), + "type": "function", + "function": { + "name": block.get("name", ""), + "arguments": json.dumps(block.get("input", {})), + }, + } + ) + return text_content, tool_calls + @staticmethod def convert_openai_content_to_anthropic( @@ -374,11 +428,9 @@ async def stream_anthropic_response( ): """Handle streaming response from Anthropic API, including usage data.""" + # https://docs.anthropic.com/en/docs/build-with-claude/streaming#full-http-stream-response async def stream_response() -> AsyncGenerator[bytes, None]: # Store parts of usage info as they arrive - captured_input_tokens = 0 - captured_output_tokens = 0 - usage_info_complete = False # Flag to check if both are found request_id = f"chatcmpl-{uuid.uuid4()}" async with ( @@ -417,23 +469,37 @@ async def stream_response() -> AsyncGenerator[bytes, None]: try: data = json.loads(data_str) openai_chunk = None + usage_data = None finish_reason = None # --- Event Processing Logic --- # Capture Input Tokens 
from message_start if event_type == "message_start": message_data = data.get("message", {}) - if "usage" in message_data: - captured_input_tokens = message_data["usage"].get( - "input_tokens", 0 - ) - captured_output_tokens = message_data["usage"].get( - "output_tokens", captured_output_tokens - ) + usage_data = AnthropicAdapter.format_anthropic_usage(message_data.get("usage", {})) + if message_data: + openai_chunk = { + "id": request_id, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "index": 0, + "delta": { + "role": message_data.get("role", "assistant"), + "content": AnthropicAdapter.translate_anthropic_content_to_openai(message_data.get("content", []))[0], + }, + "finish_reason": None, + } + ], + } elif event_type == "content_block_start": # Handle start of content blocks (text or tool_use) content_block = data.get("content_block", {}) + + usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {})) if content_block.get("type") == "tool_use": # Start of a tool call openai_chunk = { @@ -471,6 +537,8 @@ async def stream_response() -> AsyncGenerator[bytes, None]: elif event_type == "content_block_delta": delta = data.get("delta", {}) + + usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {})) if delta.get("type") == "text_delta": # Text content delta delta_content = delta.get("text", "") @@ -520,6 +588,9 @@ async def stream_response() -> AsyncGenerator[bytes, None]: # Capture Output Tokens & Finish Reason from message_delta elif event_type == "message_delta": delta_data = data.get("delta", {}) + + usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {})) + anthropic_stop_reason = delta_data.get("stop_reason") if anthropic_stop_reason: # Map Anthropic stop reason to OpenAI finish reason @@ -533,15 +604,6 @@ async def stream_response() -> AsyncGenerator[bytes, None]: anthropic_stop_reason, "stop" ) - # Check for usage at the TOP LEVEL of the message_delta event data - if "usage" in data: - usage_data_in_delta = data["usage"] - captured_output_tokens = usage_data_in_delta.get( - "output_tokens", captured_output_tokens - ) - if captured_input_tokens > 0: - usage_info_complete = True - # Capture Finish Reason from message_stop (backup for usage) elif event_type == "message_stop": # Map Anthropic stop reason to OpenAI finish reason if not already set @@ -559,19 +621,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]: anthropic_stop_reason, "stop" ) - if not usage_info_complete and "usage" in data: - usage = data["usage"] - captured_input_tokens = usage.get( - "input_tokens", captured_input_tokens - ) - captured_output_tokens = usage.get( - "output_tokens", captured_output_tokens - ) - if ( - captured_input_tokens > 0 - and captured_output_tokens > 0 - ): - usage_info_complete = True + usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {})) # --- Yielding Logic --- if openai_chunk: @@ -579,31 +629,30 @@ async def stream_response() -> AsyncGenerator[bytes, None]: openai_chunk["choices"][0]["finish_reason"] = ( finish_reason ) - yield f"data: {json.dumps(openai_chunk)}\n\n".encode() - # Check if usage info is complete *after* potential content chunk - if usage_info_complete: - final_usage_data = { - "prompt_tokens": captured_input_tokens, - "completion_tokens": captured_output_tokens, - "total_tokens": captured_input_tokens - + captured_output_tokens, - } - usage_chunk = { + if usage_data: + openai_chunk["usage"] = usage_data + + yield f"data: 
{json.dumps(openai_chunk)}\n\n".encode() + elif usage_data: + # yield the usage chunk + openai_chunk = { "id": request_id, "object": "chat.completion.chunk", "created": int(time.time()), "model": model_name, "choices": [{"index": 0, "delta": {}}], - "usage": final_usage_data, + "usage": usage_data, } - yield f"data: {json.dumps(usage_chunk)}\n\n".encode() - # Reset flag to prevent duplicate yields - usage_info_complete = False + if finish_reason: + openai_chunk["choices"][0]["finish_reason"] = ( + finish_reason + ) + yield f"data: {json.dumps(openai_chunk)}\n\n".encode() except json.JSONDecodeError as e: logger.warning( - f"Stream API error for {self.provider_name}: Failed to parse JSON: {e}" + f"Stream API error for Anthropic Base: Failed to parse JSON: {e}" ) continue except Exception as e: @@ -646,31 +695,7 @@ async def process_regular_response( if "messages" in anthropic_payload: # Messages API response content = anthropic_response.get("content", []) - text_content = "" - tool_calls = [] - - # Extract text and tool calls from content blocks - if content: # Ensure content is not None - for block in content: - if not block or not isinstance(block, dict): - continue - - if block.get("type") == "text": - text_content += block.get("text", "") - elif block.get("type") == "tool_use": - # Convert Anthropic tool use to OpenAI tool call format - tool_calls.append( - { - "id": block.get( - "id", f"call_{uuid.uuid4().hex[:8]}" - ), - "type": "function", - "function": { - "name": block.get("name", ""), - "arguments": json.dumps(block.get("input", {})), - }, - } - ) + text_content, tool_calls = AnthropicAdapter.translate_anthropic_content_to_openai(content) # Map Anthropic stop reason to OpenAI finish reason stop_reason = anthropic_response.get("stop_reason", "end_turn") @@ -696,12 +721,8 @@ async def process_regular_response( None # OpenAI expects null content when tool calls are present ) - input_tokens = anthropic_response.get("usage", {}).get( - "input_tokens", 0 - ) - output_tokens = anthropic_response.get("usage", {}).get( - "output_tokens", 0 - ) + # https://docs.anthropic.com/en/api/messages#response-usage + usage_data = AnthropicAdapter.format_anthropic_usage(anthropic_response.get("usage", {})) return { "id": completion_id, "object": "chat.completion", @@ -714,11 +735,7 @@ async def process_regular_response( "finish_reason": finish_reason, } ], - "usage": { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - "total_tokens": input_tokens + output_tokens, - }, + **({"usage": usage_data} if usage_data else {}), } else: # Legacy completion response @@ -738,9 +755,9 @@ async def process_regular_response( } ], "usage": { - "prompt_tokens": -1, - "completion_tokens": -1, - "total_tokens": -1, + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, }, } diff --git a/app/services/providers/bedrock_adapter.py b/app/services/providers/bedrock_adapter.py index 63a3f5d..038acd9 100644 --- a/app/services/providers/bedrock_adapter.py +++ b/app/services/providers/bedrock_adapter.py @@ -111,7 +111,22 @@ def mask_config(config: dict[str, Any] | None) -> dict[str, Any] | None: "aws_access_key_id": config["aws_access_key_id"][:3] + mask_str + config["aws_access_key_id"][-3:], "aws_secret_access_key": config["aws_secret_access_key"][:3] + mask_str + config["aws_secret_access_key"][-3:], } - + + @staticmethod + def format_bedrock_usage(usage_data: dict[str, Any]) -> dict[str, Any]: + """Format Bedrock usage data to OpenAI format""" + if not usage_data: + return None + 
input_tokens = usage_data.get("inputTokens", 0) + output_tokens = usage_data.get("outputTokens", 0) + total_tokens = usage_data.get("totalTokens", 0) or (input_tokens + output_tokens) + cached_tokens = usage_data.get("cacheReadInputTokens", 0) + usage_data.get("cacheWriteInputTokens", 0) + return { + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": total_tokens, + "prompt_tokens_details": {"cached_tokens": cached_tokens}, + } async def list_models(self, api_key: str) -> list[str]: """List all models (verbosely) supported by the provider""" @@ -305,9 +320,7 @@ async def _process_regular_response(self, bedrock_payload: dict[str, Any]) -> di error_message=error_text ) - input_tokens = response.get("usage", {}).get("inputTokens", 0) - output_tokens = response.get("usage", {}).get("outputTokens", 0) - total_tokens = response.get("usage", {}).get("totalTokens", 0) + usage_data = self.format_bedrock_usage(response.get("usage", {})) finish_reason = response.get("stopReason", "end_turn") finish_reason = self.BEDROCK_FINISH_REASONS_MAPPING.get(finish_reason, "stop") @@ -327,16 +340,12 @@ async def _process_regular_response(self, bedrock_payload: dict[str, Any]) -> di "finish_reason": finish_reason, } ], - "usage": { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - "total_tokens": total_tokens or (input_tokens + output_tokens), - }, + **({"usage": usage_data} if usage_data else {}), } async def _process_streaming_response(self, bedrock_payload: dict[str, Any]) -> AsyncGenerator[bytes, None]: """Process a streaming response from Bedrock API""" - final_usage_data = None + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse_stream.html finish_reason = None request_id = f"chatcmpl-{uuid.uuid4()}" created = int(time.time()) @@ -350,6 +359,7 @@ async def _process_streaming_response(self, bedrock_payload: dict[str, Any]) -> async for event in response["stream"]: # only one key in each event openai_chunk = None + usage_data = None if "messageStart" in event: role = event["messageStart"].get("role", "assistant") openai_chunk = { @@ -416,17 +426,11 @@ async def _process_streaming_response(self, bedrock_payload: dict[str, Any]) -> ], } elif "metadata" in event: - usage = event["metadata"].get('usage') - if usage: - input_tokens = usage.get("inputTokens", 0) - output_tokens = usage.get("outputTokens", 0) - total_tokens = usage.get("totalTokens", 0) or (input_tokens + output_tokens) - final_usage_data = { - "prompt_tokens": input_tokens, - "completion_tokens": output_tokens, - "total_tokens": total_tokens, - } + usage_data = self.format_bedrock_usage(event["metadata"].get('usage')) + if openai_chunk: + if usage_data: + openai_chunk["usage"] = usage_data yield f"data: {json.dumps(openai_chunk)}\n\n".encode() except Exception as e: error_text = f"Streaming completion API error for {self.provider_name}: {e}" @@ -436,19 +440,6 @@ async def _process_streaming_response(self, bedrock_payload: dict[str, Any]) -> error_code=400, error_message=error_text ) - if final_usage_data: - openai_chunk = { - "id": request_id, - "object": "chat.completion.chunk", - "created": created, - "model": bedrock_payload["modelId"], - "choices": [{ - "index": 0, - "delta": {}, - }], - "usage": final_usage_data, - } - yield f"data: {json.dumps(openai_chunk)}\n\n".encode() # Send final [DONE] message yield b"data: [DONE]\n\n" diff --git a/app/services/providers/cohere_adapter.py b/app/services/providers/cohere_adapter.py index 
27e120e..811b935 100644 --- a/app/services/providers/cohere_adapter.py +++ b/app/services/providers/cohere_adapter.py @@ -61,17 +61,21 @@ async def list_models(self, api_key: str) -> list[str]: return models @staticmethod - def convert_usage_data(usage_metadata: dict[str, Any]) -> dict[str, Any]: + def convert_usage_data(usage_data: dict[str, Any]) -> dict[str, Any]: """Convert Cohere usage data to OpenAI format""" - # cohere only billed a specific amount of tokens - usage_metadata = usage_metadata or {} - billed_tokens = usage_metadata.get("billed_units", {}) - prompt_tokens = billed_tokens.get("input_tokens", 0) - completion_tokens = billed_tokens.get("output_tokens", 0) + if not usage_data: + return None + + billed_units = usage_data.get("billed_units", {}) + tokens = usage_data.get("tokens", {}) + + # prefer billed_units over tokens + input_tokens = billed_units.get("input_tokens") or tokens.get("input_tokens") or 0 + output_tokens = billed_units.get("output_tokens") or tokens.get("output_tokens") or 0 return { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "total_tokens": input_tokens + output_tokens, } def _convert_cohere_to_openai( @@ -111,9 +115,11 @@ def _convert_cohere_to_openai( ] # Set usage estimates if available - usage_metadata = cohere_response.get("usage", {}) - if usage_metadata: - openai_response["usage"] = self.convert_usage_data(usage_metadata) + # NOTE: The usage key should be usage instead of meta + # The doc is wrong: https://docs.cohere.com/docs/chat-api#response-structure + usage_data = self.convert_usage_data(cohere_response.get("usage", {})) + if usage_data: + openai_response["usage"] = usage_data return openai_response @@ -121,6 +127,7 @@ async def _stream_cohere_response( self, api_key: str, payload: dict[str, Any] ) -> AsyncGenerator[bytes, None]: """Stream a completion request using Cohere API""" + # https://docs.cohere.com/docs/streaming model = payload["model"] try: url = f"{self._base_url}/v2/chat" @@ -145,7 +152,6 @@ async def _stream_cohere_response( # Track the message ID for consistency message_id = None created = int(time.time()) - final_usage_data = None async for chunk in response.content.iter_chunks(): if not chunk[0]: # Skip empty chunks continue @@ -190,6 +196,7 @@ async def _stream_cohere_response( ) openai_chunk = None + usage_data = None # Convert to OpenAI format based on chunk type if chunk_type == "message-start": openai_chunk = { @@ -250,11 +257,6 @@ async def _stream_cohere_response( finish_reason = cohere_chunk.get("delta", {}).get( "finish_reason", "stop" ) - usage_metadata = cohere_chunk.get("usage", {}) - if usage_metadata: - final_usage_data = self.convert_usage_data( - usage_metadata - ) openai_chunk = { "id": message_id, @@ -269,25 +271,32 @@ async def _stream_cohere_response( } ], } + + if 'usage' in cohere_chunk: + usage_data = self.convert_usage_data(cohere_chunk['usage']) if openai_chunk: + if usage_data: + openai_chunk["usage"] = usage_data yield f"data: {json.dumps(openai_chunk)}\n\n".encode() + elif usage_data: + # Normally, only message-end event type has usage data. And we would yield a openai_chunk with usage data. 
+                            # This fallback covers the unexpected case of usage arriving on another event type.
+                            usage_chunk = {
+                                "id": message_id or str(uuid.uuid4()),
+                                "object": "chat.completion.chunk",
+                                "created": created,
+                                "model": model,
+                                "choices": [{"index": 0, "delta": {}}],
+                                "usage": usage_data,
+                            }
+                            yield f"data: {json.dumps(usage_chunk)}\n\n".encode()
                     except json.JSONDecodeError:
                         logger.warning(f"Failed to parse Cohere chunk: {chunk[0]}")
                         continue
                     except Exception as e:
                         logger.error(f"Error processing Cohere chunk: {e}")
                         continue
-            if final_usage_data:
-                usage_chunk = {
-                    "id": message_id or uuid.uuid4(),
-                    "object": "chat.completion.chunk",
-                    "created": created,
-                    "model": model,
-                    "choices": [],
-                    "usage": final_usage_data,
-                }
-                yield f"data: {json.dumps(usage_chunk)}\n\n".encode()
 
             # # Send final [DONE] message
             yield b"data: [DONE]\n\n"
diff --git a/app/services/providers/google_adapter.py b/app/services/providers/google_adapter.py
index 1b5cbf3..2bcdb92 100644
--- a/app/services/providers/google_adapter.py
+++ b/app/services/providers/google_adapter.py
@@ -309,7 +309,6 @@ async def _stream_google_response(
             yield f"data: {json.dumps(initial_chunk)}\n\n".encode()
 
         request_id = f"chatcmpl-{uuid.uuid4()}"
-        final_usage_data = None  # Store usage info when found
 
         try:
             if not google_payload:
@@ -347,6 +346,7 @@
                 )
 
                 # Process response in chunks
+                # https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse
                 buffer = ""
                 async for chunk in response.content.iter_chunks():
                     if not chunk[0]:  # Empty chunk
@@ -358,6 +358,14 @@
                     while True:
                         try:
                             # Find the start of a JSON object
+                            usage_data = None
+                            openai_chunk = {
+                                "id": request_id,
+                                "object": "chat.completion.chunk",
+                                "created": int(time.time()),
+                                "model": model,
+                                "choices": [{"index": 0, "delta": {"content": ""}}],
+                            }
                             start_idx = buffer.find("{")
                             if start_idx == -1:
                                 break
@@ -369,11 +377,12 @@
 
                             # Process the JSON object
                             if "usageMetadata" in json_obj:
-                                final_usage_data = self._format_google_usage(
+                                usage_data = self.format_google_usage(
                                     json_obj["usageMetadata"]
                                 )
 
                             if "candidates" in json_obj:
+                                choices = []
                                 for c_idx, candidate in enumerate(
                                     json_obj.get("candidates", [])
                                 ):
@@ -384,43 +393,27 @@
                                     )
                                     finish_reason = candidate.get("finishReason")
 
-                                    if text_content:
-                                        chunk = {
-                                            "id": request_id,
-                                            "object": "chat.completion.chunk",
-                                            "created": int(time.time()),
-                                            "model": model,
-                                            "choices": [
-                                                {
-                                                    "index": c_idx,
-                                                    "delta": {"content": text_content},
-                                                    "finish_reason": finish_reason.lower()
-                                                    if finish_reason
-                                                    else None,
-                                                }
-                                            ],
-                                        }
-                                        yield f"data: {json.dumps(chunk)}\n\n".encode()
-                                        await asyncio.sleep(
-                                            0.05
-                                        )  # Small delay to prevent overwhelming the client
+                                    choices.append({
+                                        "index": c_idx,
+                                        "delta": {"content": text_content},
+                                        **({"finish_reason": finish_reason.lower()}
+                                           if finish_reason
+                                           else {})
+                                    })
+                                if not choices:
+                                    choices = [{"index": 0, "delta": {"content": ""}}]
+
+                                openai_chunk["choices"] = choices
+
+                            if usage_data:
+                                openai_chunk["usage"] = usage_data
+
+                            yield f"data: {json.dumps(openai_chunk)}\n\n".encode()
                         except json.JSONDecodeError:
                             # Incomplete JSON, wait for more data
                             break
 
-            # Yield final usage chunk if data was found
-            if final_usage_data:
-                usage_chunk = {
-                    "id": request_id,
-                    "object": "chat.completion.chunk",
-                    "created": int(time.time()),
-                    "model": model,
-                    "choices": [{"index": 0, "delta": {}}],
-                    "usage": final_usage_data,
-                }
-                yield f"data: {json.dumps(usage_chunk)}\n\n".encode()
-
             # Send final [DONE] message
             yield b"data: [DONE]\n\n"
@@ -437,18 +430,21 @@ async def _stream_google_response(
             yield f"data: {json.dumps(error_chunk)}\n\n".encode()
             yield b"data: [DONE]\n\n"
 
-    def _format_google_usage(self, metadata: dict) -> dict:
+    @staticmethod
+    def format_google_usage(metadata: dict) -> dict | None:
         """Format Google usage metadata to OpenAI format"""
         if not metadata:
             return None
 
         prompt_tokens = metadata.get("promptTokenCount", 0)
         completion_tokens = metadata.get("candidatesTokenCount", 0)
+        cached_tokens = metadata.get("cachedContentTokenCount", 0)
+        reasoning_tokens = metadata.get("thoughtsTokenCount", 0)
 
         return {
             "prompt_tokens": prompt_tokens,
             "completion_tokens": completion_tokens,
-            "total_tokens": metadata.get(
-                "totalTokenCount", prompt_tokens + completion_tokens
-            ),
+            "total_tokens": prompt_tokens + completion_tokens,
+            "prompt_tokens_details": {"cached_tokens": cached_tokens},
+            "completion_tokens_details": {"reasoning_tokens": reasoning_tokens},
         }
 
     @staticmethod
@@ -558,13 +554,14 @@ def convert_google_completion_response_to_openai(
         google_response: dict[str, Any], model: str
     ) -> dict[str, Any]:
         """Convert Google completion response format to OpenAI format"""
+        # https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse
         openai_response = {
             "id": f"chatcmpl-{uuid.uuid4()}",
             "object": "chat.completion",
             "created": int(time.time()),
             "model": model,
             "choices": [],
-            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "prompt_tokens_details": {"cached_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0}},
         }
 
         # Extract the candidates
@@ -606,16 +603,9 @@ def convert_google_completion_response_to_openai(
             )
 
         # Set usage estimates if available
-        usage_metadata = google_response.get("usageMetadata", {})
-        if usage_metadata:
-            prompt_tokens = usage_metadata.get("promptTokenCount", 0)
-            completion_tokens = usage_metadata.get("candidatesTokenCount", 0)
-
-            openai_response["usage"] = {
-                "prompt_tokens": prompt_tokens,
-                "completion_tokens": completion_tokens,
-                "total_tokens": prompt_tokens + completion_tokens,
-            }
+        usage_data = GoogleAdapter.format_google_usage(google_response.get("usageMetadata"))
+        if usage_data:
+            openai_response["usage"] = usage_data
 
         return openai_response
diff --git a/app/services/providers/usage_tracker_service.py b/app/services/providers/usage_tracker_service.py
new file mode 100644
index 0000000..3ad2496
--- /dev/null
+++ b/app/services/providers/usage_tracker_service.py
@@ -0,0 +1,82 @@
+from datetime import UTC
+from datetime import datetime
+from sqlalchemy import delete
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.exc import NoResultFound
+import uuid
+
+from app.core.logger import get_logger
+from app.models.usage_tracker import UsageTracker
+
+logger = get_logger(name="usage_tracker")
+
+class UsageTrackerService:
+    """Service for tracking usage of providers and forge API keys."""
+    @staticmethod
+    async def start_tracking_usage(
+        db: AsyncSession,
+        user_id: int,
+        provider_key_id: int,
+        forge_key_id: int,
+        model: str,
+        endpoint: str,
+    ) -> uuid.UUID | None:
+        try:
+            usage_tracker = UsageTracker(
+                user_id=user_id,
+                provider_key_id=provider_key_id,
+                forge_key_id=forge_key_id,
+                model=model,
+                endpoint=endpoint,
+                created_at=datetime.now(UTC),
+            )
+            db.add(usage_tracker)
+            await db.commit()
+            logger.debug(f"Started tracking usage for user {user_id} with provider {provider_key_id} and forge {forge_key_id} for model {model} and endpoint {endpoint}")
+            # Return the id of the usage tracker so callers can update it later
+            return usage_tracker.id
+        except Exception as e:
+            await db.rollback()
+            logger.error(f"Failed to track usage: {e}")
+
+    @staticmethod
+    async def update_usage_tracker(
+        db: AsyncSession,
+        usage_tracker_id: uuid.UUID,
+        input_tokens: int,
+        output_tokens: int,
+        cached_tokens: int,
+        reasoning_tokens: int,
+    ) -> None:
+        if usage_tracker_id is None:
+            return
+
+        try:
+            usage_tracker = await db.get_one(UsageTracker, usage_tracker_id)
+            usage_tracker.input_tokens = input_tokens
+            usage_tracker.output_tokens = output_tokens
+            usage_tracker.cached_tokens = cached_tokens
+            usage_tracker.reasoning_tokens = reasoning_tokens
+            usage_tracker.updated_at = datetime.now(UTC)
+            await db.commit()
+            logger.debug(f"Updated usage tracker {usage_tracker_id} with input_tokens {input_tokens}, output_tokens {output_tokens}, cached_tokens {cached_tokens}, reasoning_tokens {reasoning_tokens}")
+        except NoResultFound:
+            logger.error(f"Usage tracker not found: {usage_tracker_id}")
+        except Exception as e:
+            await db.rollback()
+            logger.error(f"Failed to update usage tracker: {e}")
+
+    @staticmethod
+    async def delete_usage_tracker_record(
+        db: AsyncSession,
+        usage_tracker_id: uuid.UUID,
+    ) -> None:
+        if usage_tracker_id is None:
+            return
+
+        try:
+            await db.execute(delete(UsageTracker).where(UsageTracker.id == usage_tracker_id))
+            await db.commit()
+        except Exception as e:
+            await db.rollback()
+            logger.error(f"Failed to delete usage tracker record: {e}")
\ No newline at end of file
diff --git a/tests/mock_testing/add_mock_provider.py b/tests/mock_testing/add_mock_provider.py
index 795ac91..9f03f2d 100755
--- a/tests/mock_testing/add_mock_provider.py
+++ b/tests/mock_testing/add_mock_provider.py
@@ -44,7 +44,7 @@ async def setup_mock_provider(username: str, force: bool = False):
             return False
 
         # Check if the mock provider already exists for this user
-        result = await db.execute(select(ProviderKey).filter(ProviderKey.user_id == user.id, ProviderKey.provider_name == "mock"))
+        result = await db.execute(select(ProviderKey).filter(ProviderKey.user_id == user.id, ProviderKey.provider_name == "mock", ProviderKey.deleted_at == None))
         existing_provider = result.scalar_one_or_none()
 
         if existing_provider and not force:
diff --git a/tests/mock_testing/setup_test_user.py b/tests/mock_testing/setup_test_user.py
index defe1df..ae18b58 100644
--- a/tests/mock_testing/setup_test_user.py
+++ b/tests/mock_testing/setup_test_user.py
@@ -95,7 +95,7 @@ async def add_mock_provider_to_user(user_id):
     async with get_db_session() as db:
         try:
             # Check if the mock provider already exists for this user
-            result = await db.execute(select(ProviderKey).filter(ProviderKey.user_id == user_id, ProviderKey.provider_name == "mock"))
+            result = await db.execute(select(ProviderKey).filter(ProviderKey.user_id == user_id, ProviderKey.provider_name == "mock", ProviderKey.deleted_at == None))
             existing_provider = result.scalar_one_or_none()
 
             if existing_provider:
diff --git a/tests/unit_tests/test_anthropic_provider.py b/tests/unit_tests/test_anthropic_provider.py
index 9b33c7b..c38b215 100644
--- a/tests/unit_tests/test_anthropic_provider.py
+++ b/tests/unit_tests/test_anthropic_provider.py
@@ -105,7 +105,7 @@ async def test_chat_completion_streaming(self):
             result,
             expected_model="claude-sonnet-4-20250514",
             expected_message=ANTHROPIC_STANDARD_CHAT_COMPLETION_RESPONSE,
-            expected_usage={"prompt_tokens": 13, "completion_tokens": 39},
+            expected_usage={"prompt_tokens": 13, "completion_tokens": 40},
         )
         assert mock_session.posted_json[0] == {
             "model": "claude-sonnet-4-20250514",
diff --git a/tests/unit_tests/test_google_provider.py b/tests/unit_tests/test_google_provider.py
index be152e4..7bc060e 100644
--- a/tests/unit_tests/test_google_provider.py
+++ b/tests/unit_tests/test_google_provider.py
@@ -121,7 +121,7 @@ async def test_chat_completion_streaming(self):
             result,
             expected_model="models/gemini-1.5-pro-latest",
             expected_message=GOOGLE_STANDARD_CHAT_COMPLETION_RESPONSE,
-            expected_usage={"prompt_tokens": 6, "completion_tokens": 16},
+            expected_usage={"prompt_tokens": 12, "completion_tokens": 16},
         )
         assert mock_session.posted_json[0] == {
             "generationConfig": {
diff --git a/tests/unit_tests/test_provider_service.py b/tests/unit_tests/test_provider_service.py
index 51f8c6d..156b953 100644
--- a/tests/unit_tests/test_provider_service.py
+++ b/tests/unit_tests/test_provider_service.py
@@ -195,65 +195,76 @@ async def test_load_provider_keys(self):
 
     async def test_get_provider_info_explicit_mapping(self):
         """Test getting provider info with an explicitly mapped model"""
         # Since keys are already loaded in setUp, _get_provider_info should work directly
-        provider, model, base_url = self.service._get_provider_info("custom-gpt")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info("custom-gpt")
         self.assertEqual(provider, "openai")
         self.assertEqual(model, "gpt-4")
         self.assertIsNone(base_url)
+        self.assertEqual(provider_key_id, self.provider_key_openai.id)
 
-        provider, model, base_url = self.service._get_provider_info("test-gemini")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info("test-gemini")
         self.assertEqual(provider, "gemini")
         self.assertEqual(model, "models/gemini-2.0-flash")
         self.assertIsNone(base_url)
+        self.assertEqual(provider_key_id, self.provider_key_google.id)
 
     async def test_get_provider_info_prefix_matching(self):
         """Test getting provider info with prefix matching"""
         # Test OpenAI prefix
-        provider, model, base_url = self.service._get_provider_info(
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
             "openai/gpt-3.5-turbo"
         )
         self.assertEqual(provider, "openai")
+        self.assertEqual(provider_key_id, self.provider_key_openai.id)
 
         # Test Anthropic prefix
-        provider, model, base_url = self.service._get_provider_info(
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
             "anthropic/claude-2"
         )
         self.assertEqual(provider, "anthropic")
+        self.assertEqual(provider_key_id, self.provider_key_anthropic.id)
 
         # Test Google prefix
-        provider, model, base_url = self.service._get_provider_info(
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
             "gemini/models/gemini-2.0-flash"
         )
         self.assertEqual(provider, "gemini")
+        self.assertEqual(provider_key_id, self.provider_key_google.id)
 
         # Test XAI prefix
-        provider, model, base_url = self.service._get_provider_info("xai/grok-2-1212")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info("xai/grok-2-1212")
         self.assertEqual(provider, "xai")
+        self.assertEqual(provider_key_id, self.provider_key_xai.id)
 
         # Test Fireworks prefix
-        provider, model, base_url = self.service._get_provider_info(
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
             "fireworks/accounts/fireworks/models/code-llama-7b"
         )
         self.assertEqual(provider, "fireworks")
+        self.assertEqual(provider_key_id, self.provider_key_fireworks.id)
 
         # Test OpenRouter prefix
-        provider, model, base_url = self.service._get_provider_info(
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
             "openrouter/openai/gpt-4o"
         )
         self.assertEqual(provider, "openrouter")
+        self.assertEqual(provider_key_id, self.provider_key_openrouter.id)
 
         # Test Together prefix
-        provider, model, base_url = self.service._get_provider_info(
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
             "together/WhereIsAI/UAE-Large-V1"
         )
         self.assertEqual(provider, "together")
+        self.assertEqual(provider_key_id, self.provider_key_together.id)
 
-        provider, model, base_url = self.service._get_provider_info("azure/gpt-4o")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info("azure/gpt-4o")
         self.assertEqual(provider, "azure")
+        self.assertEqual(provider_key_id, self.provider_key_azure.id)
 
-        provider, model, base_url = self.service._get_provider_info("bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info("bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
         self.assertEqual(provider, "bedrock")
+        self.assertEqual(provider_key_id, self.provider_key_bedrock.id)
 
     @patch("aiohttp.ClientSession.post")
     async def test_call_openai_api(self, mock_post):
diff --git a/tests/unit_tests/utils/helpers.py b/tests/unit_tests/utils/helpers.py
index 4d28272..a84928a 100644
--- a/tests/unit_tests/utils/helpers.py
+++ b/tests/unit_tests/utils/helpers.py
@@ -139,8 +139,29 @@ def process_openai_streaming_response(response: str, result: dict):
             result["content"] = result.get("content", "") + content
 
         usage = data.get("usage", {})
+        result_usage = result.get("usage", {
+            "prompt_tokens": 0,
+            "completion_tokens": 0,
+            "total_tokens": 0,
+            "prompt_tokens_details": {
+                "cached_tokens": 0,
+            },
+            "completion_tokens_details": {
+                "reasoning_tokens": 0,
+            },
+        })
         if usage:
-            result["usage"] = usage
+            prompt_tokens = usage.get("prompt_tokens", 0)
+            completion_tokens = usage.get("completion_tokens", 0)
+            total_tokens = usage.get("total_tokens", 0) or (prompt_tokens + completion_tokens)
+            cached_tokens = usage.get("prompt_tokens_details", {}).get("cached_tokens", 0)
+            reasoning_tokens = usage.get("completion_tokens_details", {}).get("reasoning_tokens", 0)
+            result_usage["prompt_tokens"] += prompt_tokens
+            result_usage["completion_tokens"] += completion_tokens
+            result_usage["total_tokens"] += total_tokens
+            result_usage["prompt_tokens_details"]["cached_tokens"] += cached_tokens
+            result_usage["completion_tokens_details"]["reasoning_tokens"] += reasoning_tokens
+            result["usage"] = result_usage
 
 
 def validate_chat_completion_streaming_response(
@@ -157,3 +178,7 @@
     usage = response["usage"]
     assert usage["prompt_tokens"] == expected_usage["prompt_tokens"]
     assert usage["completion_tokens"] == expected_usage["completion_tokens"]
+    if "prompt_tokens_details" in expected_usage:
+        assert usage["prompt_tokens_details"] == expected_usage["prompt_tokens_details"]
+    if "completion_tokens_details" in expected_usage:
+        assert usage["completion_tokens_details"] == expected_usage["completion_tokens_details"]
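
Note on the Cohere `convert_usage_data` change: the precedence logic is easiest to see in isolation. This is a self-contained sketch that mirrors the billed_units-first fallback in the diff; the sample payloads are invented for illustration, not real Cohere responses.

```python
from typing import Any


def convert_usage(usage: dict[str, Any] | None) -> dict[str, Any] | None:
    """Mirror of convert_usage_data: billed_units wins, raw tokens are the fallback."""
    if not usage:
        return None
    billed = usage.get("billed_units", {})
    tokens = usage.get("tokens", {})
    input_tokens = billed.get("input_tokens") or tokens.get("input_tokens") or 0
    output_tokens = billed.get("output_tokens") or tokens.get("output_tokens") or 0
    return {
        "prompt_tokens": input_tokens,
        "completion_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }


# billed_units present: preferred even though the raw counts differ
assert convert_usage(
    {"billed_units": {"input_tokens": 10, "output_tokens": 4},
     "tokens": {"input_tokens": 12, "output_tokens": 5}}
) == {"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}

# only raw counts present: used as the fallback
assert convert_usage({"tokens": {"input_tokens": 3}}) == {
    "prompt_tokens": 3, "completion_tokens": 0, "total_tokens": 3
}
```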
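The `format_google_usage` change maps Gemini's `usageMetadata` fields onto OpenAI's nested usage shape, including the new cached- and reasoning-token details. A minimal sketch of that mapping, with a hand-written metadata dict rather than a real API response:

```python
def format_usage(metadata: dict | None) -> dict | None:
    """Sketch of format_google_usage's field mapping."""
    if not metadata:
        return None
    prompt_tokens = metadata.get("promptTokenCount", 0)
    completion_tokens = metadata.get("candidatesTokenCount", 0)
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        # Recomputed rather than read from totalTokenCount, which on the
        # Google side can also count thought tokens.
        "total_tokens": prompt_tokens + completion_tokens,
        "prompt_tokens_details": {
            "cached_tokens": metadata.get("cachedContentTokenCount", 0),
        },
        "completion_tokens_details": {
            "reasoning_tokens": metadata.get("thoughtsTokenCount", 0),
        },
    }


assert format_usage(
    {"promptTokenCount": 6, "candidatesTokenCount": 16,
     "cachedContentTokenCount": 2, "thoughtsTokenCount": 8}
) == {
    "prompt_tokens": 6,
    "completion_tokens": 16,
    "total_tokens": 22,
    "prompt_tokens_details": {"cached_tokens": 2},
    "completion_tokens_details": {"reasoning_tokens": 8},
}
```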
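The Google streaming loop scans a growing text buffer for complete JSON objects and treats `json.JSONDecodeError` as a signal to wait for the next network chunk. The same pattern can be expressed compactly with `json.JSONDecoder.raw_decode`; this is an illustrative sketch of the technique, not the adapter's exact code.

```python
import json


def drain_json_objects(buffer: str) -> tuple[list[dict], str]:
    """Extract every complete JSON object from `buffer`; return (objects, tail)."""
    decoder = json.JSONDecoder()
    objects: list[dict] = []
    while True:
        start = buffer.find("{")
        if start == -1:
            return objects, ""
        try:
            obj, end = decoder.raw_decode(buffer, start)
        except json.JSONDecodeError:
            # Incomplete object: keep the tail and wait for the next chunk
            return objects, buffer[start:]
        objects.append(obj)
        buffer = buffer[end:]


# First network chunk ends mid-object; the tail is carried into the next call.
objs, tail = drain_json_objects('{"a": 1}{"b": ')
assert objs == [{"a": 1}] and tail == '{"b": '
objs, _ = drain_json_objects(tail + '2}')
assert objs == [{"b": 2}]
```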
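`UsageTrackerService` is written as a start/update pair: a row is created when the request begins, and the token counts are filled in once the provider reports usage. A sketch of the intended call order; `track_one_request`, `async_session_factory`, and the literal IDs and token counts are placeholders, not names from this codebase — only the `UsageTrackerService` calls match the new module.

```python
from app.services.providers.usage_tracker_service import UsageTrackerService


async def track_one_request(async_session_factory) -> None:
    # async_session_factory is an assumed AsyncSession factory (e.g. from
    # async_sessionmaker); the app would use its own session dependency.
    async with async_session_factory() as db:
        # 1. The row is created as soon as the request starts...
        tracker_id = await UsageTrackerService.start_tracking_usage(
            db,
            user_id=1,
            provider_key_id=2,
            forge_key_id=3,
            model="gpt-4",
            endpoint="/v1/chat/completions",
        )
        # ... the provider call happens here ...
        # 2. ...and the token counts are filled in once usage is known.
        await UsageTrackerService.update_usage_tracker(
            db,
            tracker_id,
            input_tokens=12,
            output_tokens=16,
            cached_tokens=0,
            reasoning_tokens=0,
        )
```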
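The mock-provider fixtures now add `ProviderKey.deleted_at == None` to their lookups, matching the new soft-deletion columns. The read/write pattern is sketched below with an assumed import path for the model; note that `.is_(None)` is the lint-friendly spelling of the same `IS NULL` comparison used in the diff.

```python
from datetime import UTC, datetime

from sqlalchemy import select

from app.models.provider_key import ProviderKey  # assumed module path


def live_provider_keys(user_id: int):
    # `.is_(None)` compiles to `deleted_at IS NULL`, identical in effect to
    # the `== None` spelling in the diff, but keeps linters (E711) quiet.
    return select(ProviderKey).filter(
        ProviderKey.user_id == user_id,
        ProviderKey.deleted_at.is_(None),
    )


def soft_delete(key: ProviderKey) -> None:
    # Soft deletion stamps the row instead of emitting a DELETE, so
    # usage_tracker rows that reference the key stay resolvable.
    key.deleted_at = datetime.now(UTC)
```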
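Finally, `process_openai_streaming_response` now sums usage across every chunk that reports it and recomputes `total_tokens` when a provider omits it; a provider that emits usage on more than one chunk therefore has its counts added together. A condensed sketch of that accumulation with a made-up two-chunk stream:

```python
def accumulate_usage(chunks: list[dict]) -> dict:
    """Sum the usage blocks found across a stream of OpenAI-style chunks."""
    totals = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
        "prompt_tokens_details": {"cached_tokens": 0},
        "completion_tokens_details": {"reasoning_tokens": 0},
    }
    for chunk in chunks:
        usage = chunk.get("usage")
        if not usage:
            continue
        prompt = usage.get("prompt_tokens", 0)
        completion = usage.get("completion_tokens", 0)
        totals["prompt_tokens"] += prompt
        totals["completion_tokens"] += completion
        # Fall back to prompt + completion when total_tokens is omitted
        totals["total_tokens"] += usage.get("total_tokens", 0) or (prompt + completion)
        totals["prompt_tokens_details"]["cached_tokens"] += (
            usage.get("prompt_tokens_details", {}).get("cached_tokens", 0)
        )
        totals["completion_tokens_details"]["reasoning_tokens"] += (
            usage.get("completion_tokens_details", {}).get("reasoning_tokens", 0)
        )
    return totals


# Two chunks each report usage, so their prompt counts are summed.
assert accumulate_usage([
    {"usage": {"prompt_tokens": 6, "completion_tokens": 0}},
    {"usage": {"prompt_tokens": 6, "completion_tokens": 16}},
])["prompt_tokens"] == 12
```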