Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions alembic/versions/9daf34d338f7_update_model_mapping_type_for_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""update model_mapping type for ProviderKey table

Revision ID: 9daf34d338f7
Revises: 08cc005a4bc8
Create Date: 2025-07-18 21:32:48.791253

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "9daf34d338f7"
down_revision = "08cc005a4bc8"
branch_labels = None
depends_on = None


def upgrade() -> None:
# Change model_mapping column from String to JSON
op.alter_column(
"provider_keys",
"model_mapping",
existing_type=sa.String(),
type_=postgresql.JSON(astext_type=sa.Text()),
existing_nullable=True,
postgresql_using="model_mapping::json",
)


def downgrade() -> None:
# Change model_mapping column from JSON back to String
op.alter_column(
"provider_keys",
"model_mapping",
existing_type=postgresql.JSON(astext_type=sa.Text()),
type_=sa.String(),
existing_nullable=True,
postgresql_using="model_mapping::text",
)
69 changes: 38 additions & 31 deletions app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session, joinedload, selectinload

from app.api.schemas.user import TokenData
from app.core.async_cache import (
Expand All @@ -23,7 +25,7 @@
forge_scope_cache_async,
get_forge_scope_cache_async,
)
from app.core.database import get_db
from app.core.database import get_db, get_async_db
from app.core.logger import get_logger
from app.core.security import (
ALGORITHM,
Expand Down Expand Up @@ -91,7 +93,7 @@ async def fetch_and_cache_jwks() -> list | None:


async def get_current_user(
db: Session = Depends(get_db), token: str = Depends(oauth2_scheme)
db: AsyncSession = Depends(get_async_db), token: str = Depends(oauth2_scheme)
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand All @@ -106,7 +108,9 @@ async def get_current_user(
token_data = TokenData(username=username)
except JWTError as err:
raise credentials_exception from err
user = db.query(User).filter(User.username == token_data.username).first()

result = await db.execute(select(User).filter(User.username == token_data.username))
user = result.scalar_one_or_none()
if user is None:
raise credentials_exception
return user
Expand Down Expand Up @@ -143,7 +147,7 @@ async def get_api_key_from_headers(request: Request) -> str:

async def get_user_by_api_key(
request: Request = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
) -> User:
"""Get user by API key from headers, with caching"""
api_key_from_header = await get_api_key_from_headers(request)
Expand Down Expand Up @@ -183,12 +187,12 @@ async def get_user_by_api_key(
# avoids an extra query later in /models.
cached_scope = await get_forge_scope_cache_async(api_key)

api_key_record = (
db.query(ForgeApiKey)
.options(joinedload(ForgeApiKey.allowed_provider_keys))
result = await db.execute(
select(ForgeApiKey)
.options(selectinload(ForgeApiKey.allowed_provider_keys))
.filter(ForgeApiKey.key == api_key_from_header, ForgeApiKey.is_active)
.first()
)
api_key_record = result.scalar_one_or_none()

if not api_key_record:
raise HTTPException(
Expand All @@ -197,12 +201,12 @@ async def get_user_by_api_key(
)

# Get the user associated with this API key and EAGER LOAD all provider keys
user = (
db.query(User)
.options(joinedload(User.provider_keys))
result = await db.execute(
select(User)
.options(selectinload(User.provider_keys))
.filter(User.id == api_key_record.user_id)
.first()
)
user = result.scalar_one_or_none()

if not user:
raise HTTPException(
Expand Down Expand Up @@ -230,7 +234,7 @@ async def get_user_by_api_key(

# Update last used timestamp for the API key
api_key_record.last_used_at = datetime.utcnow()
db.commit()
await db.commit()

# Cache the user data for future requests
await cache_user_async(api_key, user)
Expand Down Expand Up @@ -338,7 +342,7 @@ async def validate_clerk_jwt(token: str = Depends(clerk_token_header)):


async def get_current_user_from_clerk(
db: Session = Depends(get_db), token_payload: dict = Depends(validate_clerk_jwt)
db: AsyncSession = Depends(get_async_db), token_payload: dict = Depends(validate_clerk_jwt)
):
"""Get the current user from Clerk token, creating if needed"""
from urllib.parse import quote
Expand All @@ -352,7 +356,8 @@ async def get_current_user_from_clerk(
)

# Find user by clerk_user_id
user = db.query(User).filter(User.clerk_user_id == clerk_user_id).first()
result = await db.execute(select(User).filter(User.clerk_user_id == clerk_user_id))
user = result.scalar_one_or_none()

# User doesn't exist yet, create one
if not user:
Expand Down Expand Up @@ -398,7 +403,8 @@ async def get_current_user_from_clerk(
username = email

# Check if username exists and make unique if needed
existing_user = db.query(User).filter(User.username == username).first()
result = await db.execute(select(User).filter(User.username == username))
existing_user = result.scalar_one_or_none()
if existing_user:
import random

Expand All @@ -412,20 +418,22 @@ async def get_current_user_from_clerk(
username = clerk_user_id

# Check if user exists with this email
existing_user = db.query(User).filter(User.email == email).first()
result = await db.execute(select(User).filter(User.email == email))
existing_user = result.scalar_one_or_none()
if existing_user:
# Link existing user to Clerk ID
try:
existing_user.clerk_user_id = clerk_user_id
db.commit()
await db.commit()
return existing_user
except IntegrityError:
# Another request might have already linked this user or created a new one
db.rollback()
await db.rollback()
# Retry the query to get the user
user = (
db.query(User).filter(User.clerk_user_id == clerk_user_id).first()
result = await db.execute(
select(User).filter(User.clerk_user_id == clerk_user_id)
)
user = result.scalar_one_or_none()
if user:
return user
# If still no user, continue with creation attempt
Expand All @@ -439,17 +447,15 @@ async def get_current_user_from_clerk(
username=username,
clerk_user_id=clerk_user_id,
is_active=True,
hashed_password=get_password_hash(
"CLERK_AUTH_USER"
), # Add placeholder password for Clerk users
hashed_password="", # Clerk handles authentication
)
db.add(user)
db.commit()
db.refresh(user)
await db.commit()
await db.refresh(user)

# Create default TensorBlock provider for the new user
try:
create_default_tensorblock_provider_for_user(user.id, db)
await create_default_tensorblock_provider_for_user(user.id, db)
except Exception as e:
# Log error but don't fail user creation
logger.warning(
Expand All @@ -459,12 +465,13 @@ async def get_current_user_from_clerk(
return user
except IntegrityError as e:
# Handle race condition: another request might have created the user
db.rollback()
await db.rollback()
if "users_clerk_user_id_key" in str(e) or "clerk_user_id" in str(e):
# Retry the query to get the user that was created by another request
user = (
db.query(User).filter(User.clerk_user_id == clerk_user_id).first()
result = await db.execute(
select(User).filter(User.clerk_user_id == clerk_user_id)
)
user = result.scalar_one_or_none()
if user:
return user
else:
Expand Down
8 changes: 4 additions & 4 deletions app/api/routes/api_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@
import requests
from fastapi import APIRouter, Depends, HTTPException, Request, status
from jose import jwt
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.dependencies import (
get_current_active_user,
get_current_active_user_from_clerk,
)
from app.api.schemas.user import User
from app.core.database import get_db
from app.core.database import get_async_db
from app.models.user import User as UserModel

router = APIRouter()


async def get_user_from_any_auth(
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
jwt_user: UserModel | None = Depends(get_current_active_user),
clerk_user: UserModel | None = Depends(get_current_active_user_from_clerk),
) -> UserModel:
Expand All @@ -45,7 +45,7 @@ async def get_user_from_any_auth(


@router.get("/me", response_model=User)
def get_unified_current_user(
async def get_unified_current_user(
current_user: UserModel = Depends(get_user_from_any_auth),
) -> Any:
"""
Expand Down
Loading
Loading