Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions app/api/routes/claude_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.dependencies import get_user_by_api_key
from app.api.dependencies import get_user_by_api_key, get_user_details_by_api_key
from app.api.routes.proxy import _get_allowed_provider_names
from app.api.schemas.anthropic import (
AnthropicErrorResponse,
Expand Down Expand Up @@ -98,13 +98,16 @@ async def _log_and_return_error_response(
@router.post("/messages", response_model=None, tags=["Claude Code"], status_code=200)
async def create_message_proxy(
request: Request,
user: User = Depends(get_user_by_api_key),
user_details: dict[str, Any] = Depends(get_user_details_by_api_key),
db: AsyncSession = Depends(get_async_db),
) -> Union[JSONResponse, StreamingResponse]:
"""
Main endpoint for Claude Code message completions, proxied through Forge to providers.
Handles request/response conversions, streaming, and dynamic model selection.
"""
user = user_details["user"]
api_key_id = user_details["api_key_id"]

request_id = str(uuid.uuid4())
request.state.request_id = request_id
request.state.start_time_monotonic = time.monotonic()
Expand Down Expand Up @@ -224,7 +227,7 @@ async def create_message_proxy(

try:
# Use Forge's provider service to process the request
provider_service = await ProviderService.async_get_instance(user, db)
provider_service = await ProviderService.async_get_instance(user, db, api_key_id)
allowed_provider_names = await _get_allowed_provider_names(request, db)

# Process request through Forge
Expand Down
102 changes: 85 additions & 17 deletions app/api/routes/statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from enum import StrEnum
import decimal

from app.api.dependencies import get_async_db, get_current_active_user, get_current_active_user_from_clerk
from app.api.dependencies import (
get_async_db,
get_current_active_user,
get_current_active_user_from_clerk,
)
from app.models.user import User
from app.models.usage_tracker import UsageTracker
from app.models.provider_key import ProviderKey
Expand Down Expand Up @@ -40,7 +44,11 @@ async def get_usage_realtime(
# Calculate the date 7 days ago
now = datetime.now(UTC)
seven_days_ago = now - timedelta(days=7)
started_at = started_at if started_at is not None and started_at > seven_days_ago else seven_days_ago
started_at = (
started_at
if started_at is not None and started_at > seven_days_ago
else seven_days_ago
)
ended_at = ended_at if ended_at is not None and ended_at < now else None

# Build the query
Expand All @@ -51,6 +59,11 @@ async def get_usage_realtime(
ProviderKey.provider_name.label("provider_name"),
UsageTracker.model.label("model_name"),
(UsageTracker.input_tokens + UsageTracker.output_tokens).label("tokens"),
(UsageTracker.input_tokens - UsageTracker.cached_tokens).label(
"input_tokens"
),
UsageTracker.output_tokens.label("output_tokens"),
UsageTracker.cached_tokens.label("cached_tokens"),
UsageTracker.cost.label("cost"),
func.extract(
"epoch", UsageTracker.updated_at - UsageTracker.created_at
Expand All @@ -62,13 +75,15 @@ async def get_usage_realtime(
UsageTracker.user_id == current_user.id,
UsageTracker.created_at >= started_at,
ended_at is None or UsageTracker.created_at <= ended_at,
forge_key is None or or_(
forge_key is None
or or_(
ForgeApiKey.key.ilike(f"%{forge_key}%"),
ForgeApiKey.name.ilike(f"%{forge_key}%")
ForgeApiKey.name.ilike(f"%{forge_key}%"),
),
provider_name is None or ProviderKey.provider_name.ilike(f"%{provider_name}%"),
provider_name is None
or ProviderKey.provider_name.ilike(f"%{provider_name}%"),
model_name is None or UsageTracker.model.ilike(f"%{model_name}%"),
UsageTracker.updated_at.is_not(None)
UsageTracker.updated_at.is_not(None),
)
.order_by(desc(UsageTracker.created_at))
.offset(offset)
Expand All @@ -89,16 +104,18 @@ async def get_usage_realtime(
"provider_name": row.provider_name,
"model_name": row.model_name,
"tokens": row.tokens,
"input_tokens": row.input_tokens,
"output_tokens": row.output_tokens,
"cached_tokens": row.cached_tokens,
"cost": decimal.Decimal(row.cost).normalize(),
"duration": round(float(row.duration), 2)
if row.duration is not None
else 0.0,
}
)
print(usage_stats)

return [UsageRealtimeResponse(**usage_stat) for usage_stat in usage_stats]


@router.get("/usage/realtime/clerk", response_model=list[UsageRealtimeResponse])
async def get_usage_realtime_clerk(
current_user: User = Depends(get_current_active_user_from_clerk),
Expand All @@ -111,7 +128,17 @@ async def get_usage_realtime_clerk(
started_at: datetime = Query(None),
ended_at: datetime = Query(None),
):
return await get_usage_realtime(current_user, db, offset, limit, forge_key, provider_name, model_name, started_at, ended_at)
return await get_usage_realtime(
current_user,
db,
offset,
limit,
forge_key,
provider_name,
model_name,
started_at,
ended_at,
)


class UsageSummaryTimeSpan(StrEnum):
Expand Down Expand Up @@ -152,16 +179,21 @@ async def get_usage_summary(
func.sum(UsageTracker.input_tokens + UsageTracker.output_tokens).label(
"tokens"
),
func.sum(UsageTracker.input_tokens - UsageTracker.cached_tokens).label(
"input_tokens"
),
func.sum(UsageTracker.output_tokens).label("output_tokens"),
func.sum(UsageTracker.cached_tokens).label("cached_tokens"),
func.sum(UsageTracker.cost).label("cost"),
)
.join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
.where(
UsageTracker.user_id == current_user.id,
UsageTracker.created_at >= start_time,
UsageTracker.updated_at.is_not(None)
UsageTracker.updated_at.is_not(None),
)
.group_by(time_group, ForgeApiKey.name, ForgeApiKey.key)
.order_by(time_group, desc("tokens"), "forge_key")
.order_by(time_group, desc("cost"), "forge_key")
)

# Execute the query
Expand All @@ -171,19 +203,41 @@ async def get_usage_summary(
data_points = dict()
for row in rows:
if row.time_point not in data_points:
data_points[row.time_point] = {"breakdown": [], "total_tokens": 0, "total_cost": 0}
data_points[row.time_point] = {
"breakdown": [],
"total_tokens": 0,
"total_cost": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_cached_tokens": 0,
}
data_points[row.time_point]["breakdown"].append(
{"forge_key": row.forge_key, "tokens": row.tokens, "cost": decimal.Decimal(row.cost).normalize()}
{
"forge_key": row.forge_key,
"tokens": row.tokens,
"cost": decimal.Decimal(row.cost).normalize(),
"input_tokens": row.input_tokens,
"output_tokens": row.output_tokens,
"cached_tokens": row.cached_tokens,
}
)
data_points[row.time_point]["total_tokens"] += row.tokens
data_points[row.time_point]["total_cost"] += decimal.Decimal(row.cost).normalize()
data_points[row.time_point]["total_cost"] += decimal.Decimal(
row.cost
).normalize()
data_points[row.time_point]["total_input_tokens"] += row.input_tokens
data_points[row.time_point]["total_output_tokens"] += row.output_tokens
data_points[row.time_point]["total_cached_tokens"] += row.cached_tokens

return [
UsageSummaryResponse(
time_point=time_point,
breakdown=data_point["breakdown"],
total_tokens=data_point["total_tokens"],
total_cost=data_point["total_cost"],
total_input_tokens=data_point["total_input_tokens"],
total_output_tokens=data_point["total_output_tokens"],
total_cached_tokens=data_point["total_cached_tokens"],
)
for time_point, data_point in data_points.items()
]
Expand Down Expand Up @@ -231,6 +285,11 @@ async def get_forge_keys_usage(
func.sum(UsageTracker.input_tokens + UsageTracker.output_tokens).label(
"tokens"
),
func.sum(UsageTracker.input_tokens - UsageTracker.cached_tokens).label(
"input_tokens"
),
func.sum(UsageTracker.output_tokens).label("output_tokens"),
func.sum(UsageTracker.cached_tokens).label("cached_tokens"),
func.sum(UsageTracker.cost).label("cost"),
)
.join(ForgeApiKey, UsageTracker.forge_key_id == ForgeApiKey.id)
Expand All @@ -240,19 +299,28 @@ async def get_forge_keys_usage(
UsageTracker.updated_at.is_not(None),
)
.group_by(ForgeApiKey.name, ForgeApiKey.key)
.order_by(desc("tokens"), "forge_key")
.order_by(desc("cost"), "forge_key")
)

result = await db.execute(query)
rows = result.fetchall()

return [
ForgeKeysUsageSummaryResponse(forge_key=row.forge_key, tokens=row.tokens, cost=decimal.Decimal(row.cost).normalize())
ForgeKeysUsageSummaryResponse(
forge_key=row.forge_key,
tokens=row.tokens,
cost=decimal.Decimal(row.cost).normalize(),
input_tokens=row.input_tokens,
output_tokens=row.output_tokens,
cached_tokens=row.cached_tokens,
)
for row in rows
]


@router.get("/forge-keys/usage/clerk", response_model=list[ForgeKeysUsageSummaryResponse])
@router.get(
"/forge-keys/usage/clerk", response_model=list[ForgeKeysUsageSummaryResponse]
)
async def get_forge_keys_usage_clerk(
current_user: User = Depends(get_current_active_user_from_clerk),
db: AsyncSession = Depends(get_async_db),
Expand Down
24 changes: 20 additions & 4 deletions app/api/schemas/statistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ def mask_forge_name_or_key(v: str) -> str:
return v

class UsageRealtimeResponse(BaseModel):
timestamp: datetime
timestamp: datetime | str
forge_key: str
provider_name: str
model_name: str
tokens: int
input_tokens: int
output_tokens: int
cached_tokens: int
duration: float
cost: decimal.Decimal

Expand All @@ -28,14 +31,19 @@ def mask_forge_key(cls, v: str) -> str:

@field_validator('timestamp')
@classmethod
def convert_timestamp_to_iso(cls, v: datetime) -> str:
def convert_timestamp_to_iso(cls, v: datetime | str) -> str:
if isinstance(v, str):
return v
return v.isoformat()


class UsageSummaryBreakdown(BaseModel):
forge_key: str
tokens: int
cost: decimal.Decimal
input_tokens: int
output_tokens: int
cached_tokens: int

@field_validator('forge_key')
@classmethod
Expand All @@ -44,21 +52,29 @@ def mask_forge_key(cls, v: str) -> str:


class UsageSummaryResponse(BaseModel):
time_point: datetime
time_point: datetime | str
breakdown: list[UsageSummaryBreakdown]
total_tokens: int
total_cost: decimal.Decimal
total_input_tokens: int
total_output_tokens: int
total_cached_tokens: int

@field_validator('time_point')
@classmethod
def convert_timestamp_to_iso(cls, v: datetime) -> str:
def convert_timestamp_to_iso(cls, v: datetime | str) -> str:
if isinstance(v, str):
return v
return v.isoformat()


class ForgeKeysUsageSummaryResponse(BaseModel):
forge_key: str
tokens: int
cost: decimal.Decimal
input_tokens: int
output_tokens: int
cached_tokens: int

@field_validator('forge_key')
@classmethod
Expand Down
14 changes: 7 additions & 7 deletions app/services/pricing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,23 @@ async def _fetch_pricing_with_smart_caching(
exact_pricing = await async_provider_service_cache.get(exact_cache_key)

if exact_pricing and PricingService._is_pricing_valid_for_date(exact_pricing, calculation_date):
logger.debug(f"Hot cache hit for {provider_name}/{model_name}")
logger.debug(f"Exact match cache hit for {provider_name}/{model_name}")
return {**exact_pricing, 'source': 'exact_match'}

# Level 2: Try provider fallback cache (warm cache)
provider_cache_key = f"pricing:provider_fallback:{provider_name}"
provider_cache_key = f"pricing:provider_fallback:{provider_name}:{model_name}"
provider_fallback = await async_provider_service_cache.get(provider_cache_key)

if provider_fallback and PricingService._is_pricing_valid_for_date(provider_fallback, calculation_date):
logger.debug(f"Warm cache hit for provider {provider_name}")
logger.debug(f"Provider fallback cache hit for {provider_name}")
return {**provider_fallback, 'source': 'fallback_provider'}

# Level 3: Try global fallback cache (warm cache)
global_cache_key = f"pricing:global_fallback"
global_cache_key = f"pricing:global_fallback:{provider_name}:{model_name}"
global_fallback = await async_provider_service_cache.get(global_cache_key)

if global_fallback and PricingService._is_pricing_valid_for_date(global_fallback, calculation_date):
logger.debug("Warm cache hit for global fallback")
logger.debug("Global fallback cache hit")
return {**global_fallback, 'source': 'fallback_global'}

# Cache miss - hit database (this should be rare)
Expand Down Expand Up @@ -149,7 +149,7 @@ async def _fetch_from_database_with_caching(

if provider_fallback:
# Cache provider fallback (warm cache)
cache_key = f"pricing:provider_fallback:{provider_name}"
cache_key = f"pricing:provider_fallback:{provider_name}:{model_name}"
await async_provider_service_cache.set(
cache_key, provider_fallback, ttl=PricingService.FALLBACK_CACHE_TTL
)
Expand All @@ -163,7 +163,7 @@ async def _fetch_from_database_with_caching(

if global_fallback:
# Cache global fallback (warm cache)
cache_key = f"pricing:global_fallback"
cache_key = f"pricing:global_fallback:{provider_name}:{model_name}"
await async_provider_service_cache.set(
cache_key, global_fallback, ttl=PricingService.FALLBACK_CACHE_TTL
)
Expand Down
3 changes: 1 addition & 2 deletions app/services/provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,8 +598,7 @@ async def process_request(
endpoint=endpoint,
)
else:
# TODO: this shouldn't happen, but we handle it gracefully as we don't want to break the flow
# Dive deeper into this if it ever happens
# For api like list models, we don't have usage tracking
logger.info(
f"api_key_id: {self.api_key_id}, provider_key_id: {provider_key_id}"
)
Expand Down
Loading