diff --git a/app/api/routes/statistic.py b/app/api/routes/statistic.py index 0c91f63..0630968 100644 --- a/app/api/routes/statistic.py +++ b/app/api/routes/statistic.py @@ -5,6 +5,7 @@ from sqlalchemy import or_ from datetime import datetime, timedelta, UTC from enum import StrEnum +import decimal from app.api.dependencies import get_async_db, get_current_active_user, get_current_active_user_from_clerk from app.models.user import User @@ -50,6 +51,7 @@ async def get_usage_realtime( ProviderKey.provider_name.label("provider_name"), UsageTracker.model.label("model_name"), (UsageTracker.input_tokens + UsageTracker.output_tokens).label("tokens"), + UsageTracker.cost.label("cost"), func.extract( "epoch", UsageTracker.updated_at - UsageTracker.created_at ).label("duration"), @@ -65,7 +67,8 @@ async def get_usage_realtime( ForgeApiKey.name.ilike(f"%{forge_key}%") ), provider_name is None or ProviderKey.provider_name.ilike(f"%{provider_name}%"), - model_name is None or UsageTracker.model.ilike(f"%{model_name}%") + model_name is None or UsageTracker.model.ilike(f"%{model_name}%"), + UsageTracker.updated_at.is_not(None) ) .order_by(desc(UsageTracker.created_at)) .offset(offset) @@ -86,6 +89,7 @@ async def get_usage_realtime( "provider_name": row.provider_name, "model_name": row.model_name, "tokens": row.tokens, + "cost": decimal.Decimal(row.cost).normalize(), "duration": round(float(row.duration), 2) if row.duration is not None else 0.0, @@ -148,11 +152,13 @@ async def get_usage_summary( func.sum(UsageTracker.input_tokens + UsageTracker.output_tokens).label( "tokens" ), + func.sum(UsageTracker.cost).label("cost"), ) .join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id) .where( UsageTracker.user_id == current_user.id, UsageTracker.created_at >= start_time, + UsageTracker.updated_at.is_not(None) ) .group_by(time_group, ForgeApiKey.name, ForgeApiKey.key) .order_by(time_group, desc("tokens"), "forge_key") @@ -165,17 +171,19 @@ async def get_usage_summary( data_points = dict() for row in rows: if row.time_point not in data_points: - data_points[row.time_point] = {"breakdown": [], "total_tokens": 0} + data_points[row.time_point] = {"breakdown": [], "total_tokens": 0, "total_cost": 0} data_points[row.time_point]["breakdown"].append( - {"forge_key": row.forge_key, "tokens": row.tokens} + {"forge_key": row.forge_key, "tokens": row.tokens, "cost": decimal.Decimal(row.cost).normalize()} ) data_points[row.time_point]["total_tokens"] += row.tokens + data_points[row.time_point]["total_cost"] += decimal.Decimal(row.cost).normalize() return [ UsageSummaryResponse( time_point=time_point, breakdown=data_point["breakdown"], total_tokens=data_point["total_tokens"], + total_cost=data_point["total_cost"], ) for time_point, data_point in data_points.items() ] @@ -223,11 +231,13 @@ async def get_forge_keys_usage( func.sum(UsageTracker.input_tokens + UsageTracker.output_tokens).label( "tokens" ), + func.sum(UsageTracker.cost).label("cost"), ) .join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id) .where( UsageTracker.user_id == current_user.id, start_time is None or UsageTracker.created_at >= start_time, + UsageTracker.updated_at.is_not(None), ) .group_by(ForgeApiKey.name, ForgeApiKey.key) .order_by(desc("tokens"), "forge_key") @@ -237,7 +247,7 @@ async def get_forge_keys_usage( rows = result.fetchall() return [ - ForgeKeysUsageSummaryResponse(forge_key=row.forge_key, tokens=row.tokens) + ForgeKeysUsageSummaryResponse(forge_key=row.forge_key, tokens=row.tokens, cost=decimal.Decimal(row.cost).normalize()) for row in rows ] @@ -248,4 +258,4 @@ async def get_forge_keys_usage_clerk( db: AsyncSession = Depends(get_async_db), span: ForgeKeysUsageTimeSpan = Query(ForgeKeysUsageTimeSpan.week), ): - return await get_forge_keys_usage(current_user, db, span) \ No newline at end of file + return await get_forge_keys_usage(current_user, db, span) diff --git a/app/api/schemas/statistic.py b/app/api/schemas/statistic.py index 5e3f755..62e71d8 100644 --- a/app/api/schemas/statistic.py +++ b/app/api/schemas/statistic.py @@ -1,6 +1,7 @@ from pydantic import BaseModel, field_validator from datetime import datetime import re +import decimal from app.api.schemas.forge_api_key import ForgeApiKeyMasked @@ -18,6 +19,7 @@ class UsageRealtimeResponse(BaseModel): model_name: str tokens: int duration: float + cost: decimal.Decimal @field_validator('forge_key') @classmethod @@ -33,6 +35,7 @@ def convert_timestamp_to_iso(cls, v: datetime) -> str: class UsageSummaryBreakdown(BaseModel): forge_key: str tokens: int + cost: decimal.Decimal @field_validator('forge_key') @classmethod @@ -44,6 +47,7 @@ class UsageSummaryResponse(BaseModel): time_point: datetime breakdown: list[UsageSummaryBreakdown] total_tokens: int + total_cost: decimal.Decimal @field_validator('time_point') @classmethod @@ -54,6 +58,7 @@ def convert_timestamp_to_iso(cls, v: datetime) -> str: class ForgeKeysUsageSummaryResponse(BaseModel): forge_key: str tokens: int + cost: decimal.Decimal @field_validator('forge_key') @classmethod