diff --git a/app/api/routes/claude_code.py b/app/api/routes/claude_code.py
index 6503544..3ac7596 100644
--- a/app/api/routes/claude_code.py
+++ b/app/api/routes/claude_code.py
@@ -14,7 +14,7 @@
 from pydantic import ValidationError
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from app.api.dependencies import get_user_by_api_key
+from app.api.dependencies import get_user_by_api_key, get_user_details_by_api_key
 from app.api.routes.proxy import _get_allowed_provider_names
 from app.api.schemas.anthropic import (
     AnthropicErrorResponse,
@@ -98,13 +98,16 @@ async def _log_and_return_error_response(
 @router.post("/messages", response_model=None, tags=["Claude Code"], status_code=200)
 async def create_message_proxy(
     request: Request,
-    user: User = Depends(get_user_by_api_key),
+    user_details: dict[str, Any] = Depends(get_user_details_by_api_key),
     db: AsyncSession = Depends(get_async_db),
 ) -> Union[JSONResponse, StreamingResponse]:
     """
     Main endpoint for Claude Code message completions, proxied through Forge to providers.
     Handles request/response conversions, streaming, and dynamic model selection.
     """
+    user = user_details["user"]
+    api_key_id = user_details["api_key_id"]
+
     request_id = str(uuid.uuid4())
     request.state.request_id = request_id
     request.state.start_time_monotonic = time.monotonic()
@@ -224,7 +227,7 @@ async def create_message_proxy(
 
     try:
         # Use Forge's provider service to process the request
-        provider_service = await ProviderService.async_get_instance(user, db)
+        provider_service = await ProviderService.async_get_instance(user, db, api_key_id)
         allowed_provider_names = await _get_allowed_provider_names(request, db)
 
         # Process request through Forge
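Note: only the call sites of get_user_details_by_api_key appear in this patch; the dependency itself lives in app/api/dependencies. From the unpacking above it must resolve a Forge API key to both the owning user and the key's own id, so the id can be threaded into ProviderService for per-key usage attribution. A minimal sketch of the shape it has to return — the header name and the lookup helper are assumptions, not the repo's actual code:

    from typing import Any

    from fastapi import Depends, HTTPException
    from fastapi.security import APIKeyHeader
    from sqlalchemy.ext.asyncio import AsyncSession

    from app.api.dependencies import get_async_db

    api_key_header = APIKeyHeader(name="X-Forge-Key", auto_error=False)  # assumed header name

    async def get_user_details_by_api_key(
        api_key: str = Depends(api_key_header),
        db: AsyncSession = Depends(get_async_db),
    ) -> dict[str, Any]:
        forge_key = await find_active_forge_key(db, api_key)  # hypothetical lookup helper
        if forge_key is None:
            raise HTTPException(status_code=401, detail="Invalid API key")
        # create_message_proxy unpacks exactly these two entries.
        return {"user": forge_key.user, "api_key_id": forge_key.id}
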
diff --git a/app/api/routes/statistic.py b/app/api/routes/statistic.py
index 0630968..c225e98 100644
--- a/app/api/routes/statistic.py
+++ b/app/api/routes/statistic.py
@@ -7,7 +7,11 @@
 from enum import StrEnum
 import decimal
 
-from app.api.dependencies import get_async_db, get_current_active_user, get_current_active_user_from_clerk
+from app.api.dependencies import (
+    get_async_db,
+    get_current_active_user,
+    get_current_active_user_from_clerk,
+)
 from app.models.user import User
 from app.models.usage_tracker import UsageTracker
 from app.models.provider_key import ProviderKey
@@ -40,7 +44,11 @@ async def get_usage_realtime(
     # Calculate the date 7 days ago
     now = datetime.now(UTC)
     seven_days_ago = now - timedelta(days=7)
-    started_at = started_at if started_at is not None and started_at > seven_days_ago else seven_days_ago
+    started_at = (
+        started_at
+        if started_at is not None and started_at > seven_days_ago
+        else seven_days_ago
+    )
     ended_at = ended_at if ended_at is not None and ended_at < now else None
 
     # Build the query
@@ -51,6 +59,11 @@ async def get_usage_realtime(
             ProviderKey.provider_name.label("provider_name"),
             UsageTracker.model.label("model_name"),
             (UsageTracker.input_tokens + UsageTracker.output_tokens).label("tokens"),
+            (UsageTracker.input_tokens - UsageTracker.cached_tokens).label(
+                "input_tokens"
+            ),
+            UsageTracker.output_tokens.label("output_tokens"),
+            UsageTracker.cached_tokens.label("cached_tokens"),
             UsageTracker.cost.label("cost"),
             func.extract(
                 "epoch", UsageTracker.updated_at - UsageTracker.created_at
@@ -62,13 +75,15 @@ async def get_usage_realtime(
             UsageTracker.user_id == current_user.id,
             UsageTracker.created_at >= started_at,
             ended_at is None or UsageTracker.created_at <= ended_at,
-            forge_key is None or or_(
+            forge_key is None
+            or or_(
                 ForgeApiKey.key.ilike(f"%{forge_key}%"),
-                ForgeApiKey.name.ilike(f"%{forge_key}%")
+                ForgeApiKey.name.ilike(f"%{forge_key}%"),
             ),
-            provider_name is None or ProviderKey.provider_name.ilike(f"%{provider_name}%"),
+            provider_name is None
+            or ProviderKey.provider_name.ilike(f"%{provider_name}%"),
             model_name is None or UsageTracker.model.ilike(f"%{model_name}%"),
-            UsageTracker.updated_at.is_not(None)
+            UsageTracker.updated_at.is_not(None),
         )
         .order_by(desc(UsageTracker.created_at))
         .offset(offset)
@@ -89,16 +104,18 @@ async def get_usage_realtime(
                 "provider_name": row.provider_name,
                 "model_name": row.model_name,
                 "tokens": row.tokens,
+                "input_tokens": row.input_tokens,
+                "output_tokens": row.output_tokens,
+                "cached_tokens": row.cached_tokens,
                 "cost": decimal.Decimal(row.cost).normalize(),
                 "duration": round(float(row.duration), 2) if row.duration is not None else 0.0,
             }
         )
 
-    print(usage_stats)
-
     return [UsageRealtimeResponse(**usage_stat) for usage_stat in usage_stats]
+
 
 @router.get("/usage/realtime/clerk", response_model=list[UsageRealtimeResponse])
 async def get_usage_realtime_clerk(
     current_user: User = Depends(get_current_active_user_from_clerk),
     db: AsyncSession = Depends(get_async_db),
@@ -111,7 +128,17 @@ async def get_usage_realtime_clerk(
     started_at: datetime = Query(None),
     ended_at: datetime = Query(None),
 ):
-    return await get_usage_realtime(current_user, db, offset, limit, forge_key, provider_name, model_name, started_at, ended_at)
+    return await get_usage_realtime(
+        current_user,
+        db,
+        offset,
+        limit,
+        forge_key,
+        provider_name,
+        model_name,
+        started_at,
+        ended_at,
+    )
 
 
 class UsageSummaryTimeSpan(StrEnum):
@@ -152,16 +179,21 @@ async def get_usage_summary(
             func.sum(UsageTracker.input_tokens + UsageTracker.output_tokens).label(
                 "tokens"
             ),
+            func.sum(UsageTracker.input_tokens - UsageTracker.cached_tokens).label(
+                "input_tokens"
+            ),
+            func.sum(UsageTracker.output_tokens).label("output_tokens"),
+            func.sum(UsageTracker.cached_tokens).label("cached_tokens"),
             func.sum(UsageTracker.cost).label("cost"),
         )
         .join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
         .where(
             UsageTracker.user_id == current_user.id,
             UsageTracker.created_at >= start_time,
-            UsageTracker.updated_at.is_not(None)
+            UsageTracker.updated_at.is_not(None),
         )
         .group_by(time_group, ForgeApiKey.name, ForgeApiKey.key)
-        .order_by(time_group, desc("tokens"), "forge_key")
+        .order_by(time_group, desc("cost"), "forge_key")
     )
 
     # Execute the query
@@ -171,12 +203,31 @@ async def get_usage_summary(
     data_points = dict()
     for row in rows:
         if row.time_point not in data_points:
-            data_points[row.time_point] = {"breakdown": [], "total_tokens": 0, "total_cost": 0}
+            data_points[row.time_point] = {
+                "breakdown": [],
+                "total_tokens": 0,
+                "total_cost": 0,
+                "total_input_tokens": 0,
+                "total_output_tokens": 0,
+                "total_cached_tokens": 0,
+            }
         data_points[row.time_point]["breakdown"].append(
-            {"forge_key": row.forge_key, "tokens": row.tokens, "cost": decimal.Decimal(row.cost).normalize()}
+            {
+                "forge_key": row.forge_key,
+                "tokens": row.tokens,
+                "cost": decimal.Decimal(row.cost).normalize(),
+                "input_tokens": row.input_tokens,
+                "output_tokens": row.output_tokens,
+                "cached_tokens": row.cached_tokens,
+            }
         )
         data_points[row.time_point]["total_tokens"] += row.tokens
-        data_points[row.time_point]["total_cost"] += decimal.Decimal(row.cost).normalize()
+        data_points[row.time_point]["total_cost"] += decimal.Decimal(
+            row.cost
+        ).normalize()
+        data_points[row.time_point]["total_input_tokens"] += row.input_tokens
+        data_points[row.time_point]["total_output_tokens"] += row.output_tokens
+        data_points[row.time_point]["total_cached_tokens"] += row.cached_tokens
 
     return [
         UsageSummaryResponse(
@@ -184,6 +235,9 @@ async def get_usage_summary(
             breakdown=data_point["breakdown"],
             total_tokens=data_point["total_tokens"],
             total_cost=data_point["total_cost"],
+            total_input_tokens=data_point["total_input_tokens"],
+            total_output_tokens=data_point["total_output_tokens"],
+            total_cached_tokens=data_point["total_cached_tokens"],
         )
         for time_point, data_point in data_points.items()
     ]
@@ -231,6 +285,11 @@ async def get_forge_keys_usage(
             func.sum(UsageTracker.input_tokens + UsageTracker.output_tokens).label(
                 "tokens"
             ),
+            func.sum(UsageTracker.input_tokens - UsageTracker.cached_tokens).label(
+                "input_tokens"
+            ),
+            func.sum(UsageTracker.output_tokens).label("output_tokens"),
+            func.sum(UsageTracker.cached_tokens).label("cached_tokens"),
             func.sum(UsageTracker.cost).label("cost"),
         )
         .join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
@@ -240,19 +299,28 @@ async def get_forge_keys_usage(
             UsageTracker.updated_at.is_not(None),
         )
         .group_by(ForgeApiKey.name, ForgeApiKey.key)
-        .order_by(desc("tokens"), "forge_key")
+        .order_by(desc("cost"), "forge_key")
     )
 
     result = await db.execute(query)
     rows = result.fetchall()
 
     return [
-        ForgeKeysUsageSummaryResponse(forge_key=row.forge_key, tokens=row.tokens, cost=decimal.Decimal(row.cost).normalize())
+        ForgeKeysUsageSummaryResponse(
+            forge_key=row.forge_key,
+            tokens=row.tokens,
+            cost=decimal.Decimal(row.cost).normalize(),
+            input_tokens=row.input_tokens,
+            output_tokens=row.output_tokens,
+            cached_tokens=row.cached_tokens,
+        )
         for row in rows
     ]
 
 
-@router.get("/forge-keys/usage/clerk", response_model=list[ForgeKeysUsageSummaryResponse])
+@router.get(
+    "/forge-keys/usage/clerk", response_model=list[ForgeKeysUsageSummaryResponse]
+)
 async def get_forge_keys_usage_clerk(
     current_user: User = Depends(get_current_active_user_from_clerk),
     db: AsyncSession = Depends(get_async_db),
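The token columns added across these queries all follow one accounting rule: "tokens" stays the raw input + output total, while the new "input_tokens" is reported net of cache reads. A worked example, assuming (as the subtraction implies) that UsageTracker.input_tokens already includes the tokens counted in cached_tokens:

    # One UsageTracker row: 1200 input tokens, of which 1000 were cache reads,
    # plus 300 output tokens.
    input_total, output, cached = 1200, 300, 1000

    tokens = input_total + output         # 1500 -- unchanged total
    input_tokens = input_total - cached   # 200  -- uncached input, as now reported
    output_tokens = output                # 300
    cached_tokens = cached                # 1000

    # The three new buckets still sum back to the reported total.
    assert input_tokens + output_tokens + cached_tokens == tokens

Note the ordering change as well: both summary queries now rank rows by cost rather than by token count.
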
diff --git a/app/api/schemas/statistic.py b/app/api/schemas/statistic.py
index 62e71d8..c39d733 100644
--- a/app/api/schemas/statistic.py
+++ b/app/api/schemas/statistic.py
@@ -13,11 +13,14 @@ def mask_forge_name_or_key(v: str) -> str:
     return v
 
 class UsageRealtimeResponse(BaseModel):
-    timestamp: datetime
+    timestamp: datetime | str
     forge_key: str
     provider_name: str
    model_name: str
     tokens: int
+    input_tokens: int
+    output_tokens: int
+    cached_tokens: int
     duration: float
     cost: decimal.Decimal
 
@@ -28,7 +31,9 @@ def mask_forge_key(cls, v: str) -> str:
 
     @field_validator('timestamp')
     @classmethod
-    def convert_timestamp_to_iso(cls, v: datetime) -> str:
+    def convert_timestamp_to_iso(cls, v: datetime | str) -> str:
+        if isinstance(v, str):
+            return v
         return v.isoformat()
 
 
@@ -36,6 +41,9 @@ class UsageSummaryBreakdown(BaseModel):
     forge_key: str
     tokens: int
     cost: decimal.Decimal
+    input_tokens: int
+    output_tokens: int
+    cached_tokens: int
 
     @field_validator('forge_key')
     @classmethod
@@ -44,14 +52,19 @@ def mask_forge_key(cls, v: str) -> str:
 
 
 class UsageSummaryResponse(BaseModel):
-    time_point: datetime
+    time_point: datetime | str
     breakdown: list[UsageSummaryBreakdown]
     total_tokens: int
     total_cost: decimal.Decimal
+    total_input_tokens: int
+    total_output_tokens: int
+    total_cached_tokens: int
 
     @field_validator('time_point')
     @classmethod
-    def convert_timestamp_to_iso(cls, v: datetime) -> str:
+    def convert_timestamp_to_iso(cls, v: datetime | str) -> str:
+        if isinstance(v, str):
+            return v
         return v.isoformat()
 
 
@@ -59,6 +72,9 @@ class ForgeKeysUsageSummaryResponse(BaseModel):
     forge_key: str
     tokens: int
     cost: decimal.Decimal
+    input_tokens: int
+    output_tokens: int
+    cached_tokens: int
 
     @field_validator('forge_key')
     @classmethod
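The timestamp and time_point fields are widened to datetime | str and their validators made idempotent, presumably because the /clerk endpoints delegate to the non-clerk handlers: the models they return already carry ISO strings, which FastAPI then re-validates against the response_model. Restated as a plain function (the body is exactly the one in the diff):

    from datetime import UTC, datetime

    def convert_timestamp_to_iso(v: datetime | str) -> str:
        if isinstance(v, str):  # already serialized on a previous validation pass
            return v
        return v.isoformat()

    assert convert_timestamp_to_iso(datetime(2024, 5, 1, tzinfo=UTC)) == "2024-05-01T00:00:00+00:00"
    assert convert_timestamp_to_iso("2024-05-01T00:00:00+00:00") == "2024-05-01T00:00:00+00:00"
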
diff --git a/app/services/pricing_service.py b/app/services/pricing_service.py
index e06d7ed..66bf611 100644
--- a/app/services/pricing_service.py
+++ b/app/services/pricing_service.py
@@ -93,23 +93,23 @@ async def _fetch_pricing_with_smart_caching(
         exact_pricing = await async_provider_service_cache.get(exact_cache_key)
 
         if exact_pricing and PricingService._is_pricing_valid_for_date(exact_pricing, calculation_date):
-            logger.debug(f"Hot cache hit for {provider_name}/{model_name}")
+            logger.debug(f"Exact match cache hit for {provider_name}/{model_name}")
             return {**exact_pricing, 'source': 'exact_match'}
 
         # Level 2: Try provider fallback cache (warm cache)
-        provider_cache_key = f"pricing:provider_fallback:{provider_name}"
+        provider_cache_key = f"pricing:provider_fallback:{provider_name}:{model_name}"
         provider_fallback = await async_provider_service_cache.get(provider_cache_key)
 
         if provider_fallback and PricingService._is_pricing_valid_for_date(provider_fallback, calculation_date):
-            logger.debug(f"Warm cache hit for provider {provider_name}")
+            logger.debug(f"Provider fallback cache hit for {provider_name}")
             return {**provider_fallback, 'source': 'fallback_provider'}
 
         # Level 3: Try global fallback cache (warm cache)
-        global_cache_key = f"pricing:global_fallback"
+        global_cache_key = f"pricing:global_fallback:{provider_name}:{model_name}"
         global_fallback = await async_provider_service_cache.get(global_cache_key)
 
         if global_fallback and PricingService._is_pricing_valid_for_date(global_fallback, calculation_date):
-            logger.debug("Warm cache hit for global fallback")
+            logger.debug("Global fallback cache hit")
             return {**global_fallback, 'source': 'fallback_global'}
 
         # Cache miss - hit database (this should be rare)
@@ -149,7 +149,7 @@ async def _fetch_from_database_with_caching(
 
         if provider_fallback:
             # Cache provider fallback (warm cache)
-            cache_key = f"pricing:provider_fallback:{provider_name}"
+            cache_key = f"pricing:provider_fallback:{provider_name}:{model_name}"
             await async_provider_service_cache.set(
                 cache_key, provider_fallback, ttl=PricingService.FALLBACK_CACHE_TTL
             )
@@ -163,7 +163,7 @@ async def _fetch_from_database_with_caching(
 
         if global_fallback:
             # Cache global fallback (warm cache)
-            cache_key = f"pricing:global_fallback"
+            cache_key = f"pricing:global_fallback:{provider_name}:{model_name}"
             await async_provider_service_cache.set(
                 cache_key, global_fallback, ttl=PricingService.FALLBACK_CACHE_TTL
             )
diff --git a/app/services/provider_service.py b/app/services/provider_service.py
index 7df1f37..b85c243 100644
--- a/app/services/provider_service.py
+++ b/app/services/provider_service.py
@@ -598,8 +598,7 @@ async def process_request(
                 endpoint=endpoint,
             )
         else:
-            # TODO: this shouldn't happen, but we handle it gracefully as we don't want to break the flow
-            # Dive deeper into this if it ever happens
+            # For APIs like list-models, there is no usage tracking to record
             logger.info(
                 f"api_key_id: {self.api_key_id}, provider_key_id: {provider_key_id}"
             )
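The pricing_service.py change fixes a cache-poisoning hazard: the fallback keys previously ignored the model (and, for the global tier, the provider too), so a fallback price cached while resolving one model could be returned for any other model until the TTL expired. After the patch every tier is keyed per provider/model pair, roughly as below — the exact-match key's shape is an assumption, since its construction sits outside this hunk:

    provider_name, model_name = "openai", "gpt-4o"  # illustrative values

    exact_key = f"pricing:exact:{provider_name}:{model_name}"  # assumed shape
    provider_fallback_key = f"pricing:provider_fallback:{provider_name}:{model_name}"
    global_fallback_key = f"pricing:global_fallback:{provider_name}:{model_name}"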