Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions alembic/versions/4a685a55c5cd_create_usage_tracker_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""create usage tracker table

Revision ID: 4a685a55c5cd
Revises: 9daf34d338f7
Create Date: 2025-08-02 12:29:07.955645

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID
import uuid


# revision identifiers, used by Alembic.
revision = '4a685a55c5cd'
down_revision = '9daf34d338f7'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"usage_tracker",
sa.Column("id", UUID(as_uuid=True), nullable=False, primary_key=True, default=uuid.uuid4),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("provider_key_id", sa.Integer(), nullable=False),
sa.Column("forge_key_id", sa.Integer(), nullable=False),
sa.Column("model", sa.String(), nullable=True),
sa.Column("endpoint", sa.String(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("input_tokens", sa.Integer(), nullable=True),
sa.Column("output_tokens", sa.Integer(), nullable=True),
sa.Column("cached_tokens", sa.Integer(), nullable=True),
sa.Column("reasoning_tokens", sa.Integer(), nullable=True),

sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["provider_key_id"], ["provider_keys.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["forge_key_id"], ["forge_api_keys.id"], ondelete="CASCADE"),
)


def downgrade() -> None:
op.drop_table("usage_tracker")
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""enable soft deletion for provider keys and api keys

Revision ID: 831fc2cf16ee
Revises: 4a685a55c5cd
Create Date: 2025-08-02 17:50:12.224293

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '831fc2cf16ee'
down_revision = '4a685a55c5cd'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column('provider_keys', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
op.add_column('forge_api_keys', sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True))
op.alter_column('forge_api_keys', 'key', nullable=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why does "key" can be nullable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the user is trying to deleted this record. I would like to wipe out the key value to be None so there is no security issue. We only keep the id/name and the created/deleted timestamp for reference check.

Copy link
Contributor Author

@lingtonglu lingtonglu Aug 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same goes for provider keys, I would wipe out the encrypted_api_key value there. Probably base_url as well?



def downgrade() -> None:
op.drop_column('provider_keys', 'deleted_at')
op.drop_column('forge_api_keys', 'deleted_at')
op.alter_column('forge_api_keys', 'key', nullable=False)
38 changes: 35 additions & 3 deletions app/api/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import contextlib
import json
from typing import Any

# Add environment variables for Clerk
import os
Expand Down Expand Up @@ -147,11 +148,12 @@ async def get_api_key_from_headers(request: Request) -> str:
status_code=status.HTTP_401_UNAUTHORIZED,
detail="API key not found in headers",
)


async def get_user_by_api_key(
request: Request = None,
db: AsyncSession = Depends(get_async_db),
include_api_key_id: bool = False,
) -> User:
"""Get user by API key from headers, with caching"""
api_key_from_header = await get_api_key_from_headers(request)
Expand Down Expand Up @@ -185,7 +187,22 @@ async def get_user_by_api_key(
# Return a transient User object from cached data, not a managed one.
# This avoids the db.merge() call and its expensive SELECT query.
# Downstream code can access attributes, but not lazy-load relationships.
return User(**cached_user.model_dump())
if not include_api_key_id:
return User(**cached_user.model_dump())
else:
result = await db.execute(
select(ForgeApiKey)
.options(selectinload(ForgeApiKey.allowed_provider_keys))
.filter(ForgeApiKey.key == api_key_from_header, ForgeApiKey.is_active, ForgeApiKey.deleted_at == None)
)
api_key_record = result.scalar_one_or_none()

if not api_key_record:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
)
return User(**cached_user.model_dump()), api_key_record.id

# Try scope cache first – this doesn't remove the need to verify the key, but it
# avoids an extra query later in /models.
Expand All @@ -194,7 +211,7 @@ async def get_user_by_api_key(
result = await db.execute(
select(ForgeApiKey)
.options(selectinload(ForgeApiKey.allowed_provider_keys))
.filter(ForgeApiKey.key == api_key_from_header, ForgeApiKey.is_active)
.filter(ForgeApiKey.key == api_key_from_header, ForgeApiKey.is_active, ForgeApiKey.deleted_at == None)
)
api_key_record = result.scalar_one_or_none()

Expand Down Expand Up @@ -243,9 +260,24 @@ async def get_user_by_api_key(
# Cache the user data for future requests
await cache_user_async(api_key, user)

if include_api_key_id:
return user, api_key_record.id
return user


async def get_user_details_by_api_key(
request: Request = None,
db: AsyncSession = Depends(get_async_db),
) -> dict[str, Any]:
"""Get user details by API key from headers, with caching"""
user, api_key_id = await get_user_by_api_key(request, db, include_api_key_id=True)

return {
"user": user,
"api_key_id": api_key_id,
}


async def validate_clerk_jwt(token: str = Depends(clerk_token_header)):
"""
Validate a Clerk JWT token using JWKS from Clerk.
Expand Down
26 changes: 19 additions & 7 deletions app/api/routes/api_keys.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Any
from datetime import UTC, datetime

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

Expand All @@ -18,7 +19,7 @@
from app.core.async_cache import invalidate_forge_scope_cache_async, invalidate_user_cache_async, invalidate_provider_service_cache_async
from app.core.database import get_async_db
from app.core.security import generate_forge_api_key
from app.models.forge_api_key import ForgeApiKey
from app.models.forge_api_key import ForgeApiKey, forge_api_key_provider_scope_association
from app.models.provider_key import ProviderKey as ProviderKeyModel
from app.models.user import User as UserModel

Expand All @@ -36,7 +37,7 @@ async def _get_api_keys_internal(
result = await db.execute(
select(ForgeApiKey)
.options(selectinload(ForgeApiKey.allowed_provider_keys))
.filter(ForgeApiKey.user_id == current_user.id)
.filter(ForgeApiKey.user_id == current_user.id, ForgeApiKey.deleted_at == None)
)
api_keys = result.scalars().all()

Expand Down Expand Up @@ -71,6 +72,7 @@ async def _create_api_key_internal(
select(ProviderKeyModel).filter(
ProviderKeyModel.id.in_(api_key_create.allowed_provider_key_ids),
ProviderKeyModel.user_id == current_user.id,
ProviderKeyModel.deleted_at == None,
)
)
allowed_providers = result.scalars().all()
Expand Down Expand Up @@ -103,7 +105,7 @@ async def _update_api_key_internal(
result = await db.execute(
select(ForgeApiKey)
.options(selectinload(ForgeApiKey.allowed_provider_keys))
.filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id)
.filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id, ForgeApiKey.deleted_at == None)
)
db_api_key = result.scalar_one_or_none()

Expand All @@ -126,6 +128,7 @@ async def _update_api_key_internal(
select(ProviderKeyModel).filter(
ProviderKeyModel.id.in_(api_key_update.allowed_provider_key_ids),
ProviderKeyModel.user_id == current_user.id,
ProviderKeyModel.deleted_at == None,
)
)
allowed_providers = result.scalars().all()
Expand Down Expand Up @@ -161,7 +164,7 @@ async def _delete_api_key_internal(
result = await db.execute(
select(ForgeApiKey)
.options(selectinload(ForgeApiKey.allowed_provider_keys))
.filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id)
.filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id, ForgeApiKey.deleted_at == None)
)
db_api_key = result.scalar_one_or_none()

Expand All @@ -180,7 +183,16 @@ async def _delete_api_key_internal(
"allowed_provider_key_ids": [pk.id for pk in db_api_key.allowed_provider_keys],
}

await db.delete(db_api_key)
# do soft deletion here. Set the deleted_at column to the current time
db_api_key.deleted_at = datetime.now(UTC)

# Delete the record from forge_api_key_provider_scope_association where forge_api_key_id matches current id
await db.execute(
delete(forge_api_key_provider_scope_association).where(
forge_api_key_provider_scope_association.c.forge_api_key_id == db_api_key.id
)
)

await db.commit()

await invalidate_user_cache_async(key_to_invalidate)
Expand All @@ -198,7 +210,7 @@ async def _regenerate_api_key_internal(
result = await db.execute(
select(ForgeApiKey)
.options(selectinload(ForgeApiKey.allowed_provider_keys))
.filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id)
.filter(ForgeApiKey.id == key_id, ForgeApiKey.user_id == current_user.id, ForgeApiKey.deleted_at == None)
)
db_api_key = result.scalar_one_or_none()

Expand Down
38 changes: 28 additions & 10 deletions app/api/routes/provider_keys.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import json
from datetime import UTC, datetime
from typing import Any

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from starlette import status

from app.api.dependencies import (
get_current_active_user,
Expand All @@ -16,10 +15,11 @@
ProviderKeyUpdate,
ProviderKeyUpsertItem,
)
from app.core.async_cache import invalidate_provider_service_cache_async
from app.core.async_cache import invalidate_provider_service_cache_async, invalidate_forge_scope_cache_async
from app.core.database import get_async_db
from app.core.logger import get_logger
from app.core.security import decrypt_api_key, encrypt_api_key
from app.models.forge_api_key import forge_api_key_provider_scope_association
from app.models.provider_key import ProviderKey as ProviderKeyModel
from app.models.user import User as UserModel
from app.services.providers.adapter_factory import ProviderAdapterFactory
Expand Down Expand Up @@ -56,7 +56,7 @@ async def _get_provider_keys_internal(
Internal logic to get all provider keys for the current user.
"""
result = await db.execute(
select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id)
select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id, ProviderKeyModel.deleted_at == None)
)
provider_keys = result.scalars().all()
return [ProviderKey.model_validate(pk) for pk in provider_keys]
Expand Down Expand Up @@ -94,6 +94,7 @@ async def _create_provider_key_internal(
select(ProviderKeyModel).filter(
ProviderKeyModel.user_id == current_user.id,
ProviderKeyModel.provider_name == provider_key_create.provider_name,
ProviderKeyModel.deleted_at == None,
)
)
existing_key = result.scalar_one_or_none()
Expand Down Expand Up @@ -148,6 +149,7 @@ async def _update_provider_key_internal(
select(ProviderKeyModel).filter(
ProviderKeyModel.provider_name == provider_name,
ProviderKeyModel.user_id == current_user.id,
ProviderKeyModel.deleted_at == None,
)
)
db_provider_key = result.scalar_one_or_none()
Expand Down Expand Up @@ -175,6 +177,7 @@ async def _process_provider_key_delete_data(
select(ProviderKeyModel).filter(
ProviderKeyModel.provider_name == provider_name,
ProviderKeyModel.user_id == user_id,
ProviderKeyModel.deleted_at == None,
)
)
db_provider_key = result.scalar_one_or_none()
Expand All @@ -185,9 +188,18 @@ async def _process_provider_key_delete_data(
# Store the provider key data before deletion
provider_key_data = ProviderKey.model_validate(db_provider_key)

await db.delete(db_provider_key)
# do soft deletion here. Set the deleted_at column to the current time
db_provider_key.deleted_at = datetime.now(UTC)

# Delete the record from forge_api_key_provider_scope_association where provider_key_id matches current id
await db.execute(
delete(forge_api_key_provider_scope_association).where(
forge_api_key_provider_scope_association.c.provider_key_id == db_provider_key.id,
)
)
scoped_forge_api_keys = db_provider_key.scoped_forge_api_keys

return provider_key_data
return provider_key_data, [scoped_forge_api_key.key for scoped_forge_api_key in scoped_forge_api_keys]


async def _delete_provider_key_internal(
Expand All @@ -196,11 +208,13 @@ async def _delete_provider_key_internal(
"""
Internal logic to delete a provider key for the current user.
"""
provider_key_data = await _process_provider_key_delete_data(db, provider_name, current_user.id)
provider_key_data, scoped_forge_api_keys = await _process_provider_key_delete_data(db, provider_name, current_user.id)
await db.commit()

# Invalidate caches after deleting a provider key
await invalidate_provider_service_cache_async(current_user.id)
for scoped_forge_api_key in scoped_forge_api_keys:
await invalidate_forge_scope_cache_async(scoped_forge_api_key)

return provider_key_data

Expand Down Expand Up @@ -302,13 +316,14 @@ async def _batch_upsert_provider_keys_internal(

# 1. Fetch all existing keys for the user
result = await db.execute(
select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id)
select(ProviderKeyModel).filter(ProviderKeyModel.user_id == current_user.id, ProviderKeyModel.deleted_at == None)
)
existing_keys_query = result.scalars().all()
# 2. Map them by provider_name for efficient lookup
existing_keys_map: dict[str, ProviderKeyModel] = {
key.provider_name: key for key in existing_keys_query
}
invalidated_forge_api_keys = set()

for item in items:
if "****" in item.api_key:
Expand All @@ -320,7 +335,8 @@ async def _batch_upsert_provider_keys_internal(
# Handle deletion if api_key is "DELETE"
if item.api_key == "DELETE":
if existing_provider_key:
await _process_provider_key_delete_data(db, item.provider_name, current_user.id)
_, scoped_forge_api_keys = await _process_provider_key_delete_data(db, item.provider_name, current_user.id)
invalidated_forge_api_keys.update(scoped_forge_api_keys)
processed = True
elif existing_provider_key: # Update existing key
db_key_to_process = await _process_provider_key_update_data(existing_provider_key, ProviderKeyUpdate.model_validate(item.model_dump(exclude_unset=True)))
Expand Down Expand Up @@ -357,6 +373,8 @@ async def _batch_upsert_provider_keys_internal(
await db.refresh(key) # Refresh each key to get DB-generated values like id, timestamps
processed_keys = [ProviderKey.model_validate(key) for key in processed_keys]
await invalidate_provider_service_cache_async(current_user.id)
for key in invalidated_forge_api_keys:
await invalidate_forge_scope_cache_async(key)
except Exception as e:
await db.rollback()
error_message_prefix = "Error during final commit/refresh in batch upsert"
Expand Down
Loading
Loading