From b2d0ba961d6d1d55386dac8ddd58968fc8ba8934 Mon Sep 17 00:00:00 2001 From: Alan Chen Date: Wed, 11 Feb 2026 09:41:59 -0500 Subject: [PATCH 1/6] feat(server): add usage limits, plans, and code redemption system --- config.yaml | 9 + scripts/migrations/002_add_user_plan.sql | 56 ++++ src/config/settings.py | 9 + src/server/app/chat.py | 17 +- src/server/app/plans.py | 37 +++ src/server/app/setup.py | 13 + src/server/app/usage.py | 97 +++++++ src/server/app/users.py | 29 ++- src/server/app/workspaces.py | 3 +- src/server/database/redemption.py | 133 ++++++++++ src/server/database/user.py | 17 +- src/server/dependencies/__init__.py | 0 src/server/dependencies/usage_limits.py | 100 +++++++ src/server/models/user.py | 29 +++ src/server/services/plan_service.py | 135 ++++++++++ src/server/services/usage_limiter.py | 315 +++++++++++++++++++++++ 16 files changed, 983 insertions(+), 16 deletions(-) create mode 100644 scripts/migrations/002_add_user_plan.sql create mode 100644 src/server/app/plans.py create mode 100644 src/server/app/usage.py create mode 100644 src/server/database/redemption.py create mode 100644 src/server/dependencies/__init__.py create mode 100644 src/server/dependencies/usage_limits.py create mode 100644 src/server/services/plan_service.py create mode 100644 src/server/services/usage_limiter.py diff --git a/config.yaml b/config.yaml index 44ab0d69..4ae67001 100644 --- a/config.yaml +++ b/config.yaml @@ -98,3 +98,12 @@ redis: warm_after_invalidation: true # Pre-populate cache after invalidation +# ============================================================================= +# USAGE LIMITS CONFIGURATION +# ============================================================================= +# Per-user, tier-based usage limits. Only enforced when auth is enabled. +# Complete no-op in local dev (when SUPABASE_URL is unset). 
+usage_limits: + enabled: true + plan_cache_ttl: 300 # seconds — PlanService in-memory cache TTL + burst_counter_ttl: 300 # seconds (5 minutes, short-lived burst window) diff --git a/scripts/migrations/002_add_user_plan.sql b/scripts/migrations/002_add_user_plan.sql new file mode 100644 index 00000000..beedaa42 --- /dev/null +++ b/scripts/migrations/002_add_user_plan.sql @@ -0,0 +1,56 @@ +-- Migration 002: Plans table + user plan_id FK + redemption codes system +-- Purpose: Move tier/plan definitions from config.yaml to DB for dynamic management + +-- 1. Plans table (source of truth for tiers) +CREATE TABLE IF NOT EXISTS plans ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) UNIQUE NOT NULL, + display_name VARCHAR(100) NOT NULL, + rank INT NOT NULL DEFAULT 0, + daily_credits NUMERIC(10,2) NOT NULL DEFAULT 500.0, + max_active_workspaces INT NOT NULL DEFAULT 3, + max_concurrent_requests INT NOT NULL DEFAULT 5, + is_default BOOLEAN NOT NULL DEFAULT FALSE, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); +CREATE UNIQUE INDEX IF NOT EXISTS idx_plans_default ON plans (is_default) WHERE is_default = TRUE; +CREATE UNIQUE INDEX IF NOT EXISTS idx_plans_rank ON plans (rank); + +-- 2. Seed initial plans (-1 = unlimited for credits/workspaces; presumably also for concurrent requests) +INSERT INTO plans (name, display_name, rank, daily_credits, max_active_workspaces, max_concurrent_requests, is_default) VALUES + ('free', 'Free', 0, 1000.0, 3, 5, TRUE), + ('pro', 'Pro', 1, 5000.0, 10, 20, FALSE), + ('enterprise', 'Enterprise', 2, -1, -1, -1, FALSE) +ON CONFLICT (name) DO NOTHING; + +-- 3. 
Add plan_id FK to users +ALTER TABLE users ADD COLUMN IF NOT EXISTS plan_id INT; +UPDATE users SET plan_id = (SELECT id FROM plans WHERE is_default = TRUE LIMIT 1) WHERE plan_id IS NULL; +ALTER TABLE users ALTER COLUMN plan_id SET NOT NULL; +ALTER TABLE users ALTER COLUMN plan_id SET DEFAULT 1; +ALTER TABLE users ADD CONSTRAINT fk_users_plan FOREIGN KEY (plan_id) REFERENCES plans(id); +CREATE INDEX IF NOT EXISTS idx_users_plan_id ON users (plan_id); + +-- 4. Redemption codes +CREATE TABLE IF NOT EXISTS redemption_codes ( + code VARCHAR(50) PRIMARY KEY, + plan_id INT NOT NULL REFERENCES plans(id), + max_redemptions INT NOT NULL DEFAULT 1, + current_redemptions INT NOT NULL DEFAULT 0, + expires_at TIMESTAMPTZ, + created_at TIMESTAMPTZ DEFAULT NOW(), + is_active BOOLEAN NOT NULL DEFAULT TRUE +); + +-- 5. Redemption history (plan names as strings for audit trail) +CREATE TABLE IF NOT EXISTS redemption_history ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + code VARCHAR(50) NOT NULL REFERENCES redemption_codes(code), + user_id VARCHAR(255) NOT NULL REFERENCES users(user_id) ON UPDATE CASCADE, + previous_plan VARCHAR(50) NOT NULL, + new_plan VARCHAR(50) NOT NULL, + redeemed_at TIMESTAMPTZ DEFAULT NOW(), + UNIQUE(code, user_id) +); +CREATE INDEX IF NOT EXISTS idx_redemption_history_user ON redemption_history(user_id); diff --git a/src/config/settings.py b/src/config/settings.py index 9ab1eb4d..81516d26 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -411,6 +411,15 @@ def get_search_api() -> str: +# ============================================================================= +# Usage Limits Configuration +# ============================================================================= + +def get_usage_limits_config() -> Dict[str, Any]: + """Get the full usage_limits section from config.yaml.""" + return get_config('usage_limits', {}) + + # ============================================================================= # Redis Configuration # 
============================================================================= diff --git a/src/server/app/chat.py b/src/server/app/chat.py index 79c47eac..44fbd947 100644 --- a/src/server/app/chat.py +++ b/src/server/app/chat.py @@ -70,6 +70,8 @@ ) from src.server.utils.image_context import parse_image_contexts, inject_image_context from src.server.utils.api import CurrentUserId +from src.server.dependencies.usage_limits import ChatRateLimited +from src.server.services.usage_limiter import UsageLimiter # Locale/timezone configuration from src.config.settings import ( @@ -88,7 +90,7 @@ @router.post("/stream") -async def chat_stream(request: ChatRequest, user_id: CurrentUserId): +async def chat_stream(request: ChatRequest, user_id: ChatRateLimited): """ Stream PTC agent responses as Server-Sent Events. @@ -424,6 +426,10 @@ async def _astream_flash_workflow( raise + finally: + # Release burst slot for flash workflows (synchronous, no background task) + await UsageLimiter.release_burst_slot(user_id) + async def _astream_workflow( request: ChatRequest, @@ -902,6 +908,9 @@ async def on_background_workflow_complete(): f"[PTC_CHAT] Background completion persistence failed for {thread_id}: {e}", exc_info=True, ) + finally: + # Release burst slot so it doesn't block future requests + await UsageLimiter.release_burst_slot(user_id) # Start workflow in background with event buffering task_info = await manager.start_workflow( @@ -1015,6 +1024,9 @@ async def on_background_workflow_complete(): # Cancel the background workflow await manager.cancel_workflow(thread_id) + # Release burst slot on cancellation + await UsageLimiter.release_burst_slot(user_id) + registry_store = BackgroundRegistryStore.get_instance() await registry_store.cancel_and_clear(thread_id, force=True) else: @@ -1047,6 +1059,9 @@ async def on_background_workflow_complete(): # Phase 4: Error Recovery with Retry Logic # ===================================================================== + # Release burst slot on 
error so it doesn't block future requests + await UsageLimiter.release_burst_slot(user_id) + # Get token/tool usage for billing even on errors _per_call_records = token_callback.per_call_records if token_callback else None _tool_usage = handler.get_tool_usage() if handler else None diff --git a/src/server/app/plans.py b/src/server/app/plans.py new file mode 100644 index 00000000..c97519b1 --- /dev/null +++ b/src/server/app/plans.py @@ -0,0 +1,37 @@ +""" +Plans API Router. + +Public endpoint — no auth required. +Frontend uses this to display available plans. + +Endpoints: +- GET /api/v1/plans — List all plans +""" + +from fastapi import APIRouter + +from src.server.services.plan_service import PlanService + +router = APIRouter(prefix="/api/v1/plans", tags=["Plans"]) + + +@router.get("") +async def list_plans(): + """Return all plans ordered by rank.""" + svc = PlanService.get_instance() + await svc.ensure_loaded() + plans = svc.get_all_plans() + return { + "plans": [ + { + "id": p.id, + "name": p.name, + "display_name": p.display_name, + "rank": p.rank, + "daily_credits": p.daily_credits, + "max_active_workspaces": p.max_active_workspaces, + "max_concurrent_requests": p.max_concurrent_requests, + } + for p in plans + ] + } diff --git a/src/server/app/setup.py b/src/server/app/setup.py index cbcbbcf8..5fe3837d 100644 --- a/src/server/app/setup.py +++ b/src/server/app/setup.py @@ -114,6 +114,15 @@ async def lifespan(app: FastAPI): logger.warning(f"Redis cache initialization failed: {e}") logger.warning("Server will continue without caching") + # Initialize PlanService (load plans from DB into memory) + try: + from src.server.services.plan_service import PlanService + plan_svc = PlanService.get_instance() + await plan_svc.refresh() + except Exception as e: + logger.warning(f"PlanService initialization failed: {e}") + logger.warning("Using fallback plan definitions") + # Start BackgroundTaskManager cleanup task try: manager = BackgroundTaskManager.get_instance() @@ 
-324,6 +333,8 @@ async def send_wrapper(message): from src.server.app.portfolio import router as portfolio_router from src.server.app.infoflow import router as infoflow_router from src.server.app.sec_proxy import router as sec_proxy_router +from src.server.app.usage import router as usage_router +from src.server.app.plans import router as plans_router # Include all routers app.include_router(chat_router) # /api/v1/chat/* - Main chat endpoint @@ -341,4 +352,6 @@ async def send_wrapper(message): app.include_router(portfolio_router) # /api/v1/users/me/portfolio/* - Portfolio management app.include_router(infoflow_router) # /api/v1/infoflow/* - InfoFlow content feed app.include_router(sec_proxy_router) # /api/v1/sec-proxy/* - SEC EDGAR document proxy +app.include_router(usage_router) # /api/v1/usage/* - Usage limits and code redemption +app.include_router(plans_router) # /api/v1/plans - Plan definitions (public) app.include_router(health_router) # /health - Health check diff --git a/src/server/app/usage.py b/src/server/app/usage.py new file mode 100644 index 00000000..9cba3a93 --- /dev/null +++ b/src/server/app/usage.py @@ -0,0 +1,97 @@ +""" +Usage Limits and Code Redemption API Router. + +Endpoints: +- GET /api/v1/usage — Current usage status (credits, workspaces, plan) +- POST /api/v1/usage/redeem — Redeem a code to upgrade plan +""" + +import logging + +from fastapi import APIRouter, HTTPException + +from src.server.utils.api import CurrentUserId +from src.server.services.usage_limiter import UsageLimiter +from src.server.database.redemption import redeem_code +from src.server.models.user import RedeemCodeRequest, RedeemCodeResponse +from src.server.services.plan_service import PlanService + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api/v1/usage", tags=["Usage"]) + + +@router.get("") +async def get_usage_status(user_id: CurrentUserId): + """ + Get current usage status for the authenticated user. 
+ + Returns plan info, credit usage, and workspace usage. + When limits are disabled, returns limits_enabled=false. + """ + svc = PlanService.get_instance() + await svc.ensure_loaded() + + def _plan_obj(plan_info): + return { + 'id': plan_info.id, + 'name': plan_info.name, + 'display_name': plan_info.display_name, + 'rank': plan_info.rank, + } + + if not UsageLimiter.is_enabled(): + return { + 'limits_enabled': False, + 'plan': _plan_obj(svc.get_default_plan()), + 'credits': {'used': 0.0, 'limit': -1, 'remaining': -1}, + 'workspaces': {'active': 0, 'limit': -1, 'remaining': -1}, + } + + plan = await UsageLimiter.get_user_plan(user_id) + daily_credit_limit = plan.daily_credits + workspace_limit = plan.max_active_workspaces + + used_credits = await UsageLimiter.get_daily_credit_usage(user_id) + active_workspaces = await UsageLimiter.get_active_workspace_count(user_id) + + credits_remaining = max(0.0, daily_credit_limit - used_credits) if daily_credit_limit != -1 else -1 + workspace_remaining = max(0, workspace_limit - active_workspaces) if workspace_limit != -1 else -1 + + return { + 'limits_enabled': True, + 'plan': _plan_obj(plan), + 'credits': { + 'used': round(used_credits, 2), + 'limit': daily_credit_limit, + 'remaining': round(credits_remaining, 2) if credits_remaining != -1 else -1, + }, + 'workspaces': { + 'active': active_workspaces, + 'limit': workspace_limit, + 'remaining': workspace_remaining, + }, + } + + +@router.post("/redeem", response_model=RedeemCodeResponse) +async def redeem_usage_code(request: RedeemCodeRequest, user_id: CurrentUserId): + """ + Redeem a code to upgrade the user's plan. + + The code is validated and applied in a single database transaction. + On success, the Redis plan cache is flushed so new limits take effect immediately. 
+ """ + try: + result = await redeem_code(user_id, request.code) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + # Flush Redis plan cache so new limits take effect immediately + await UsageLimiter.flush_plan_cache(user_id) + + return RedeemCodeResponse( + previous_plan=result['previous_plan'], + new_plan=result['new_plan'], + message=f"Plan upgraded to {result['new_plan']}", + ) diff --git a/src/server/app/users.py b/src/server/app/users.py index 4743385f..c28a3e3d 100644 --- a/src/server/app/users.py +++ b/src/server/app/users.py @@ -41,12 +41,29 @@ UserUpdate, UserWithPreferencesResponse, ) +from src.server.services.plan_service import PlanService from src.server.utils.api import CurrentUserId, handle_api_exceptions, raise_not_found logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/v1", tags=["Users"]) + +async def enrich_user_with_plan(user_dict: dict) -> dict: + """Replace raw plan_id with a plan object for UserResponse serialization.""" + svc = PlanService.get_instance() + await svc.ensure_loaded() + plan_id = user_dict.pop('plan_id', None) + plan = svc.get_plan(plan_id) if plan_id else svc.get_default_plan() + user_dict['plan'] = { + 'id': plan.id, + 'name': plan.name, + 'display_name': plan.display_name, + 'rank': plan.rank, + } + return user_dict + + # ==================== Auth Sync ==================== @@ -81,7 +98,7 @@ async def sync_user( result = await get_user_with_preferences(user_id) if not result: raise_not_found("User") - user_resp = UserResponse.model_validate(result["user"]) + user_resp = UserResponse.model_validate(await enrich_user_with_plan(result["user"])) pref_resp = None if result.get("preferences"): pref_resp = UserPreferencesResponse.model_validate(result["preferences"]) @@ -97,7 +114,7 @@ async def sync_user( result = await get_user_with_preferences(user_id) if not result: raise_not_found("User") - user_resp = UserResponse.model_validate(result["user"]) + user_resp = 
UserResponse.model_validate(await enrich_user_with_plan(result["user"])) pref_resp = None if result.get("preferences"): pref_resp = UserPreferencesResponse.model_validate(result["preferences"]) @@ -110,7 +127,7 @@ async def sync_user( name=body.name, avatar_url=body.avatar_url, ) - user_resp = UserResponse.model_validate(user) + user_resp = UserResponse.model_validate(await enrich_user_with_plan(user)) return UserWithPreferencesResponse(user=user_resp, preferences=None) @@ -148,7 +165,7 @@ async def create_user( ) logger.info(f"Created user {user_id}") - return UserResponse.model_validate(user) + return UserResponse.model_validate(await enrich_user_with_plan(user)) @router.get("/users/me", response_model=UserWithPreferencesResponse) @@ -173,7 +190,7 @@ async def get_current_user(user_id: CurrentUserId): if not result: raise_not_found("User") - user_response = UserResponse.model_validate(result["user"]) + user_response = UserResponse.model_validate(await enrich_user_with_plan(result["user"])) preferences_response = None if result["preferences"]: preferences_response = UserPreferencesResponse.model_validate(result["preferences"]) @@ -227,7 +244,7 @@ async def update_current_user( # Get preferences for combined response preferences = await db_get_user_preferences(user_id) - user_response = UserResponse.model_validate(user) + user_response = UserResponse.model_validate(await enrich_user_with_plan(user)) preferences_response = None if preferences: preferences_response = UserPreferencesResponse.model_validate(preferences) diff --git a/src/server/app/workspaces.py b/src/server/app/workspaces.py index fb0bae4d..0ca67f98 100644 --- a/src/server/app/workspaces.py +++ b/src/server/app/workspaces.py @@ -20,6 +20,7 @@ from fastapi import APIRouter, HTTPException, Query from src.server.utils.api import CurrentUserId +from src.server.dependencies.usage_limits import WorkspaceLimitCheck from src.server.database.workspace import ( get_workspace as db_get_workspace, 
get_workspaces_for_user, @@ -68,7 +69,7 @@ def _workspace_to_response(workspace: dict) -> WorkspaceResponse: @router.post("", response_model=WorkspaceResponse, status_code=201) async def create_workspace( request: WorkspaceCreate, - x_user_id: CurrentUserId, + x_user_id: WorkspaceLimitCheck, ): """ Create a new workspace with dedicated sandbox. diff --git a/src/server/database/redemption.py b/src/server/database/redemption.py new file mode 100644 index 00000000..8a24326d --- /dev/null +++ b/src/server/database/redemption.py @@ -0,0 +1,133 @@ +""" +Database functions for code redemption. + +Provides a single atomic function that validates and redeems a code +in one database transaction. +""" + +import logging +from typing import Any, Dict + +from psycopg.rows import dict_row + +from src.server.database.conversation import get_db_connection +from src.server.services.plan_service import PlanService + +logger = logging.getLogger(__name__) + + +async def redeem_code(user_id: str, code: str) -> Dict[str, Any]: + """ + Validate and redeem a code in a single transaction. + + Checks: + 1. Code exists and is_active + 2. Code not expired + 3. Code not exhausted (current_redemptions < max_redemptions, or max=-1) + 4. User hasn't already redeemed this code + 5. Code tier is strictly higher than user's current tier (no downgrade, no same-plan redemption) + + On success: updates users.plan_id, increments current_redemptions, + inserts redemption_history row (with plan names for audit). 
+ + Args: + user_id: The user redeeming the code + code: The redemption code (case-insensitive, will be uppercased) + + Returns: + {"previous_plan": "free", "new_plan": "pro", "code": "PROMO123"} + + Raises: + ValueError: With specific message on any validation failure + """ + code = code.strip().upper() + + svc = PlanService.get_instance() + await svc.ensure_loaded() + + async with get_db_connection() as conn: + # Run everything in a single transaction + async with conn.transaction(): + async with conn.cursor(row_factory=dict_row) as cur: + # 1. Look up the code + await cur.execute( + "SELECT * FROM redemption_codes WHERE code = %s FOR UPDATE", + (code,), + ) + code_row = await cur.fetchone() + + if not code_row: + raise ValueError("Invalid code") + + if not code_row['is_active']: + raise ValueError("Code is no longer active") + + # 2. Check expiry + if code_row['expires_at'] is not None: + from datetime import datetime, timezone as tz + now = datetime.now(tz.utc) + if now > code_row['expires_at']: + raise ValueError("Code has expired") + + # 3. Check exhaustion + if code_row['max_redemptions'] != -1: + if code_row['current_redemptions'] >= code_row['max_redemptions']: + raise ValueError("Code has been fully redeemed") + + # 4. Check double-redeem + await cur.execute( + "SELECT id FROM redemption_history WHERE code = %s AND user_id = %s", + (code, user_id), + ) + if await cur.fetchone(): + raise ValueError("You have already redeemed this code") + + # 5. 
Get user's current plan_id + await cur.execute( + "SELECT plan_id FROM users WHERE user_id = %s FOR UPDATE", + (user_id,), + ) + user_row = await cur.fetchone() + if not user_row: + raise ValueError("User not found") + + current_plan_id = user_row['plan_id'] + target_plan_id = code_row['plan_id'] + + # Resolve to PlanInfo for rank comparison and name display + current_plan = svc.get_plan(current_plan_id) + target_plan = svc.get_plan(target_plan_id) + + # No downgrade check + if target_plan.rank <= current_plan.rank: + if target_plan_id == current_plan_id: + raise ValueError(f"You are already on the {current_plan.display_name} plan") + raise ValueError(f"Cannot downgrade from {current_plan.display_name} to {target_plan.display_name}") + + # All checks passed — apply the upgrade + await cur.execute( + "UPDATE users SET plan_id = %s, updated_at = NOW() WHERE user_id = %s", + (target_plan_id, user_id), + ) + + await cur.execute( + "UPDATE redemption_codes SET current_redemptions = current_redemptions + 1 WHERE code = %s", + (code,), + ) + + # Audit trail uses plan names (strings) for readability + await cur.execute(""" + INSERT INTO redemption_history (code, user_id, previous_plan, new_plan) + VALUES (%s, %s, %s, %s) + """, (code, user_id, current_plan.name, target_plan.name)) + + logger.info( + f"[redemption] User {user_id} redeemed code {code}: " + f"{current_plan.name} -> {target_plan.name}" + ) + + return { + 'previous_plan': current_plan.name, + 'new_plan': target_plan.name, + 'code': code, + } diff --git a/src/server/database/user.py b/src/server/database/user.py index a4f14901..b5a51cc8 100644 --- a/src/server/database/user.py +++ b/src/server/database/user.py @@ -67,7 +67,7 @@ async def create_user( VALUES (%s, %s, %s, %s, %s, %s, FALSE, NOW(), NOW()) RETURNING user_id, email, name, avatar_url, timezone, locale, - onboarding_completed, created_at, updated_at, last_login_at + onboarding_completed, plan_id, created_at, updated_at, last_login_at """, (user_id, 
email, name, avatar_url, timezone, locale)) result = await cur.fetchone() @@ -82,7 +82,7 @@ async def find_user_by_email(email: str) -> Optional[Dict[str, Any]]: await cur.execute(""" SELECT user_id, email, name, avatar_url, timezone, locale, - onboarding_completed, created_at, updated_at, last_login_at + onboarding_completed, plan_id, created_at, updated_at, last_login_at FROM users WHERE email = %s LIMIT 1 @@ -104,7 +104,7 @@ async def migrate_user_id(old_user_id: str, new_user_id: str) -> Optional[Dict[s WHERE user_id = %s RETURNING user_id, email, name, avatar_url, timezone, locale, - onboarding_completed, created_at, updated_at, last_login_at + onboarding_completed, plan_id, created_at, updated_at, last_login_at """, (new_user_id, old_user_id)) result = await cur.fetchone() if result: @@ -139,7 +139,7 @@ async def create_user_from_auth( updated_at = NOW() RETURNING user_id, email, name, avatar_url, timezone, locale, - onboarding_completed, created_at, updated_at, last_login_at + onboarding_completed, plan_id, created_at, updated_at, last_login_at """, (user_id, email, name, avatar_url)) result = await cur.fetchone() logger.info(f"[user_db] create_user_from_auth user_id={user_id}") @@ -161,7 +161,7 @@ async def get_user(user_id: str) -> Optional[Dict[str, Any]]: await cur.execute(""" SELECT user_id, email, name, avatar_url, timezone, locale, - onboarding_completed, created_at, updated_at, last_login_at + onboarding_completed, plan_id, created_at, updated_at, last_login_at FROM users WHERE user_id = %s """, (user_id,)) @@ -212,7 +212,7 @@ async def update_user( returning_columns = [ "user_id", "email", "name", "avatar_url", "timezone", "locale", - "onboarding_completed", "created_at", "updated_at", "last_login_at", + "onboarding_completed", "plan_id", "created_at", "updated_at", "last_login_at", ] query, params = builder.build( @@ -274,7 +274,7 @@ async def upsert_user( updated_at = NOW() RETURNING user_id, email, name, avatar_url, timezone, locale, - 
onboarding_completed, created_at, updated_at, last_login_at + onboarding_completed, plan_id, created_at, updated_at, last_login_at """, (user_id, email, name, avatar_url, timezone, locale)) result = await cur.fetchone() @@ -457,7 +457,7 @@ async def get_user_with_preferences(user_id: str) -> Optional[Dict[str, Any]]: await cur.execute(""" SELECT u.user_id, u.email, u.name, u.avatar_url, u.timezone, u.locale, - u.onboarding_completed, u.created_at, u.updated_at, u.last_login_at, + u.onboarding_completed, u.plan_id, u.created_at, u.updated_at, u.last_login_at, p.preference_id, p.risk_preference, p.investment_preference, p.agent_preference, p.other_preference, p.created_at as pref_created_at, p.updated_at as pref_updated_at @@ -479,6 +479,7 @@ async def get_user_with_preferences(user_id: str) -> Optional[Dict[str, Any]]: 'timezone': result['timezone'], 'locale': result['locale'], 'onboarding_completed': result['onboarding_completed'], + 'plan_id': result['plan_id'], 'created_at': result['created_at'], 'updated_at': result['updated_at'], 'last_login_at': result['last_login_at'], diff --git a/src/server/dependencies/__init__.py b/src/server/dependencies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/server/dependencies/usage_limits.py b/src/server/dependencies/usage_limits.py new file mode 100644 index 00000000..e937742e --- /dev/null +++ b/src/server/dependencies/usage_limits.py @@ -0,0 +1,100 @@ +""" +FastAPI dependencies for usage limit enforcement. + +Provides two dependencies that compose with get_current_user_id: +- ChatRateLimited: Enforces daily credit limit + burst guard +- WorkspaceLimitCheck: Enforces active workspace limits + +Both are complete no-ops when auth is disabled. 
+""" + +from typing import Annotated + +from fastapi import Depends, HTTPException + +from src.server.utils.api import get_current_user_id +from src.server.services.usage_limiter import UsageLimiter + + +async def enforce_chat_limit( + user_id: str = Depends(get_current_user_id), +) -> str: + """ + FastAPI dependency: enforce daily credit limit + burst guard. + + Layer 1: DB credit check (SUM total_credits today vs tier daily_credits) + Layer 2: Redis burst guard (concurrent in-flight request cap) + + Returns user_id on success, raises HTTPException(429) if over limit. + """ + if not UsageLimiter.is_enabled(): + return user_id + + result = await UsageLimiter.check_chat_limit(user_id) + + if not result['allowed']: + # Determine which limit was hit for the message + is_credit_limit = result['remaining_credits'] == 0.0 and result['credit_limit'] != -1 + if is_credit_limit: + message = 'Daily credit limit reached' + limit_type = 'credit_limit' + else: + message = 'Too many concurrent requests, please wait' + limit_type = 'burst_limit' + + raise HTTPException( + status_code=429, + detail={ + 'message': message, + 'type': limit_type, + 'used_credits': result['used_credits'], + 'credit_limit': result['credit_limit'], + 'remaining_credits': result['remaining_credits'], + 'retry_after': result['retry_after'], + }, + headers={ + 'Retry-After': str(result['retry_after'] or 30), + 'X-RateLimit-Limit': str(result['credit_limit']), + 'X-RateLimit-Remaining': str(result['remaining_credits']), + }, + ) + + return user_id + + +async def enforce_workspace_limit( + user_id: str = Depends(get_current_user_id), +) -> str: + """ + FastAPI dependency: enforce active workspace limit. + + Queries DB for active workspace count. + Returns user_id on success, raises HTTPException(429) if at limit. 
+ """ + if not UsageLimiter.is_enabled(): + return user_id + + result = await UsageLimiter.check_workspace_limit(user_id) + + if not result['allowed']: + raise HTTPException( + status_code=429, + detail={ + 'message': 'Active workspace limit reached', + 'type': 'workspace_limit', + 'current': result['current'], + 'limit': result['limit'], + 'remaining': result['remaining'], + }, + headers={ + 'X-RateLimit-Limit': str(result['limit']), + 'X-RateLimit-Remaining': '0', + }, + ) + + return user_id + + +# Annotated types for cleaner endpoint signatures +ChatRateLimited = Annotated[str, Depends(enforce_chat_limit)] +WorkspaceLimitCheck = Annotated[str, Depends(enforce_workspace_limit)] diff --git a/src/server/models/user.py b/src/server/models/user.py index b0ca3832..3b81ebe5 100644 --- a/src/server/models/user.py +++ b/src/server/models/user.py @@ -201,6 +201,15 @@ class UserUpdate(UserBase): ) +class PlanResponse(BaseModel): + """Nested plan object returned inside UserResponse.""" + + id: int = Field(description="Plan ID") + name: str = Field(description="Plan internal name") + display_name: str = Field(description="Plan display name") + rank: int = Field(description="Plan rank (0 = lowest)") + + class UserResponse(UserBase): """Response model for user details.""" @@ -208,6 +217,7 @@ class UserResponse(UserBase): onboarding_completed: bool = Field( default=False, description="Whether onboarding is completed" ) + plan: PlanResponse = Field(description="User plan details") created_at: datetime = Field(description="Creation timestamp") updated_at: datetime = Field(description="Last update timestamp") last_login_at: Optional[datetime] = Field(None, description="Last login timestamp") @@ -216,6 +226,25 @@ class Config: from_attributes = True +# ============================================================================= +# Code Redemption Models +# ============================================================================= + + +class RedeemCodeRequest(BaseModel): + 
"""Request model for redeeming a code.""" + + code: str = Field(..., min_length=1, max_length=50, description="Redemption code") + + +class RedeemCodeResponse(BaseModel): + """Response model for successful code redemption.""" + + previous_plan: str = Field(description="Plan before redemption") + new_plan: str = Field(description="Plan after redemption") + message: str = Field(description="Human-readable success message") + + # ============================================================================= # User Preferences Models # ============================================================================= diff --git a/src/server/services/plan_service.py b/src/server/services/plan_service.py new file mode 100644 index 00000000..f69698cb --- /dev/null +++ b/src/server/services/plan_service.py @@ -0,0 +1,135 @@ +""" +PlanService — singleton that caches plan definitions from the `plans` DB table. + +Usage: + svc = PlanService.get_instance() + await svc.ensure_loaded() # async — call once (or periodically) + plan = svc.get_plan(plan_id) # sync dict lookup after that +""" + +import logging +import time +from dataclasses import dataclass +from typing import Dict, List, Optional + +from src.config.settings import get_usage_limits_config + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class PlanInfo: + id: int + name: str + display_name: str + rank: int + daily_credits: float + max_active_workspaces: int + max_concurrent_requests: int + is_default: bool + + +# Hardcoded fallback if DB is unreachable on first boot +_FALLBACK_PLAN = PlanInfo( + id=1, + name="free", + display_name="Free", + rank=0, + daily_credits=50.0, + max_active_workspaces=3, + max_concurrent_requests=5, + is_default=True, +) + + +class PlanService: + _instance: Optional["PlanService"] = None + + def __init__(self) -> None: + self._plans_by_id: Dict[int, PlanInfo] = {} + self._plans_by_name: Dict[str, PlanInfo] = {} + self._default_plan: Optional[PlanInfo] = None + self._loaded_at: 
float = 0.0 + + # ── singleton ──────────────────────────────────────────────── + @classmethod + def get_instance(cls) -> "PlanService": + if cls._instance is None: + cls._instance = cls() + return cls._instance + + # ── async DB loading ───────────────────────────────────────── + async def refresh(self) -> None: + """Load all plans from DB into memory.""" + try: + from src.server.database.conversation import get_db_connection + from psycopg.rows import dict_row + + async with get_db_connection() as conn: + async with conn.cursor(row_factory=dict_row) as cur: + await cur.execute( + "SELECT id, name, display_name, rank, daily_credits, " + "max_active_workspaces, max_concurrent_requests, is_default " + "FROM plans ORDER BY rank" + ) + rows = await cur.fetchall() + + by_id: Dict[int, PlanInfo] = {} + by_name: Dict[str, PlanInfo] = {} + default: Optional[PlanInfo] = None + + for row in rows: + plan = PlanInfo( + id=row["id"], + name=row["name"], + display_name=row["display_name"], + rank=row["rank"], + daily_credits=float(row["daily_credits"]), + max_active_workspaces=row["max_active_workspaces"], + max_concurrent_requests=row["max_concurrent_requests"], + is_default=row["is_default"], + ) + by_id[plan.id] = plan + by_name[plan.name] = plan + if plan.is_default: + default = plan + + self._plans_by_id = by_id + self._plans_by_name = by_name + self._default_plan = default or (list(by_id.values())[0] if by_id else _FALLBACK_PLAN) + self._loaded_at = time.monotonic() + logger.info(f"[PlanService] Loaded {len(by_id)} plans") + + except Exception as e: + logger.warning(f"[PlanService] DB load failed, using fallback: {e}") + if not self._plans_by_id: + self._plans_by_id = {_FALLBACK_PLAN.id: _FALLBACK_PLAN} + self._plans_by_name = {_FALLBACK_PLAN.name: _FALLBACK_PLAN} + self._default_plan = _FALLBACK_PLAN + self._loaded_at = time.monotonic() + + async def ensure_loaded(self) -> None: + """Refresh if not yet loaded or TTL expired.""" + config = get_usage_limits_config() + ttl 
= config.get("plan_cache_ttl", 300) + if not self._plans_by_id or (time.monotonic() - self._loaded_at) > ttl: + await self.refresh() + + # ── sync lookups (pure dict reads) ─────────────────────────── + def get_plan(self, plan_id: Optional[int]) -> PlanInfo: + if plan_id is not None and plan_id in self._plans_by_id: + return self._plans_by_id[plan_id] + return self.get_default_plan() + + def get_plan_by_name(self, name: str) -> Optional[PlanInfo]: + return self._plans_by_name.get(name) + + def get_default_plan(self) -> PlanInfo: + return self._default_plan or _FALLBACK_PLAN + + def get_all_plans(self) -> List[PlanInfo]: + return sorted(self._plans_by_id.values(), key=lambda p: p.rank) + + def get_rank(self, plan_id: int) -> int: + plan = self._plans_by_id.get(plan_id) + return plan.rank if plan else -1 diff --git a/src/server/services/usage_limiter.py b/src/server/services/usage_limiter.py new file mode 100644 index 00000000..5024e6f3 --- /dev/null +++ b/src/server/services/usage_limiter.py @@ -0,0 +1,315 @@ +""" +Usage limiter service for per-user, tier-based credit limiting. + +Two enforcement layers: +1. **Credit limit** (DB): SUM(total_credits) from conversation_usage today. + This is the real limit. Checked before each request. +2. **Burst guard** (Redis): INCR counter to cap concurrent in-flight requests. + Prevents a user from firing many requests before any completes and writes credits. + +Workspace limit uses a simple DB COUNT — unchanged. + +Graceful degradation: if Redis or DB is down, requests are allowed (not blocked). +Complete no-op when auth is disabled. 
+""" + +import logging +from datetime import datetime, timedelta, timezone + +from src.server.auth.jwt_bearer import _AUTH_ENABLED +from src.config.settings import get_usage_limits_config +from src.server.services.plan_service import PlanService, PlanInfo + +logger = logging.getLogger(__name__) + + +class UsageLimiter: + """Static methods for usage limit checking.""" + + @staticmethod + def is_enabled() -> bool: + """Check if usage limits are active (auth enabled + config enabled).""" + if not _AUTH_ENABLED: + return False + config = get_usage_limits_config() + return bool(config.get('enabled', False)) + + # ===================================================================== + # Plan lookup (Redis-cached) + # ===================================================================== + + @staticmethod + async def get_user_plan(user_id: str) -> PlanInfo: + """ + Get user's PlanInfo with Redis caching (caches plan_id). + + Falls back to DB lookup on cache miss. Returns default plan on any error. 
+ """ + svc = PlanService.get_instance() + await svc.ensure_loaded() + default = svc.get_default_plan() + + if not UsageLimiter.is_enabled(): + return default + + cache_key = f"user:plan:{user_id}" + config = get_usage_limits_config() + cache_ttl = config.get('plan_cache_ttl', 300) + + # Try Redis cache first (stores plan_id as int) + try: + from src.utils.cache.redis_cache import get_cache_client + cache = get_cache_client() + if cache.client: + cached = await cache.client.get(cache_key) + if cached: + plan_id = int(cached.decode() if isinstance(cached, bytes) else cached) + return svc.get_plan(plan_id) + except Exception as e: + logger.debug(f"[usage_limiter] Redis cache read failed: {e}") + + # Cache miss — query DB for plan_id + try: + from src.server.database.user import get_user + user = await get_user(user_id) + plan_id = user.get('plan_id') if user else None + except Exception as e: + logger.warning(f"[usage_limiter] DB lookup failed for {user_id}: {e}") + return default + + plan = svc.get_plan(plan_id) + + # Write back to cache (store plan_id) + try: + from src.utils.cache.redis_cache import get_cache_client + cache = get_cache_client() + if cache.client: + await cache.client.set(cache_key, str(plan.id), ex=cache_ttl) + except Exception as e: + logger.debug(f"[usage_limiter] Redis cache write failed: {e}") + + return plan + + @staticmethod + async def flush_plan_cache(user_id: str) -> None: + """Delete the cached plan for a user (call after plan upgrade).""" + try: + from src.utils.cache.redis_cache import get_cache_client + cache = get_cache_client() + if cache.client: + await cache.client.delete(f"user:plan:{user_id}") + except Exception as e: + logger.debug(f"[usage_limiter] Failed to flush plan cache: {e}") + + # ===================================================================== + # Credit-based chat limit + # ===================================================================== + + @staticmethod + async def check_chat_limit(user_id: str) -> dict: + 
""" + Check daily credit limit + burst guard. + + Two layers: + 1. DB credit check: SUM(total_credits) for today vs daily_credits tier limit + 2. Redis burst guard: INCR counter vs max_concurrent_requests + + Returns: + {allowed, used_credits, credit_limit, remaining_credits, retry_after, + burst_count, burst_limit} + """ + if not UsageLimiter.is_enabled(): + return { + 'allowed': True, 'used_credits': 0.0, 'credit_limit': -1, + 'remaining_credits': -1, 'retry_after': None, + 'burst_count': 0, 'burst_limit': -1, + } + + plan = await UsageLimiter.get_user_plan(user_id) + daily_credit_limit = plan.daily_credits + max_concurrent = plan.max_concurrent_requests + + # --- Layer 1: DB credit check --- + if daily_credit_limit != -1: + used_credits = await UsageLimiter.get_daily_credit_usage(user_id) + if used_credits >= daily_credit_limit: + now = datetime.now(timezone.utc) + next_midnight = now.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta(days=1) + retry_after = int((next_midnight - now).total_seconds()) + return { + 'allowed': False, + 'used_credits': used_credits, + 'credit_limit': daily_credit_limit, + 'remaining_credits': 0.0, + 'retry_after': retry_after, + 'burst_count': 0, + 'burst_limit': max_concurrent, + } + else: + used_credits = 0.0 # Don't bother querying for unlimited + + # --- Layer 2: Redis burst guard --- + if max_concurrent != -1: + burst_result = await UsageLimiter._check_burst_guard(user_id, max_concurrent) + if not burst_result['allowed']: + return { + 'allowed': False, + 'used_credits': used_credits, + 'credit_limit': daily_credit_limit, + 'remaining_credits': max(0.0, daily_credit_limit - used_credits) if daily_credit_limit != -1 else -1, + 'retry_after': 30, # Short retry for burst + 'burst_count': burst_result['count'], + 'burst_limit': max_concurrent, + } + burst_count = burst_result['count'] + else: + burst_count = 0 + + remaining = max(0.0, daily_credit_limit - used_credits) if daily_credit_limit != -1 else -1 + + return { + 
'allowed': True, + 'used_credits': used_credits, + 'credit_limit': daily_credit_limit, + 'remaining_credits': remaining, + 'retry_after': None, + 'burst_count': burst_count, + 'burst_limit': max_concurrent, + } + + @staticmethod + async def release_burst_slot(user_id: str) -> None: + """Decrement the burst counter after a request completes or is rejected.""" + try: + from src.utils.cache.redis_cache import get_cache_client + cache = get_cache_client() + if cache.client: + counter_key = f"usage:burst:{user_id}" + val = await cache.client.decr(counter_key) + # Don't let it go negative + if val is not None and int(val) < 0: + await cache.client.set(counter_key, 0, keepttl=True) + except Exception as e: + logger.debug(f"[usage_limiter] Failed to release burst slot: {e}") + + @staticmethod + async def _check_burst_guard(user_id: str, max_concurrent: int) -> dict: + """ + Redis-based burst guard. Prevents too many concurrent in-flight requests. + + Returns: + {allowed: bool, count: int} + """ + config = get_usage_limits_config() + burst_ttl = config.get('burst_counter_ttl', 300) + counter_key = f"usage:burst:{user_id}" + + try: + from src.utils.cache.redis_cache import get_cache_client + cache = get_cache_client() + if not cache.client: + return {'allowed': True, 'count': 0} + + pipe = cache.client.pipeline() + pipe.incr(counter_key) + pipe.expire(counter_key, burst_ttl) + results = await pipe.execute() + current = results[0] + + if current > max_concurrent: + # Over burst limit — roll back + await cache.client.decr(counter_key) + return {'allowed': False, 'count': current - 1} + + return {'allowed': True, 'count': current} + + except Exception as e: + logger.warning(f"[usage_limiter] Redis burst guard failed: {e}") + return {'allowed': True, 'count': 0} + + # ===================================================================== + # Credit usage from DB (source of truth) + # ===================================================================== + + @staticmethod + async 
def get_daily_credit_usage(user_id: str) -> float: + """ + Get today's total credits consumed from conversation_usage (DB truth). + + Used for both limit checking and reporting. + """ + try: + from src.server.database.conversation import get_db_connection + from psycopg.rows import dict_row + async with get_db_connection() as conn: + async with conn.cursor(row_factory=dict_row) as cur: + await cur.execute( + """ + SELECT COALESCE(SUM(total_credits), 0) as total + FROM conversation_usage + WHERE user_id = %s + AND timestamp >= (CURRENT_DATE AT TIME ZONE 'UTC') + """, + (user_id,), + ) + result = await cur.fetchone() + return float(result['total']) if result else 0.0 + except Exception as e: + logger.warning(f"[usage_limiter] DB credit query failed for {user_id}: {e}") + return 0.0 + + # ===================================================================== + # Workspace limit (count-based, unchanged) + # ===================================================================== + + @staticmethod + async def check_workspace_limit(user_id: str) -> dict: + """ + Check if user can create another workspace. 
+ + Returns: + {allowed: bool, current: int, limit: int, remaining: int} + """ + if not UsageLimiter.is_enabled(): + return {'allowed': True, 'current': 0, 'limit': -1, 'remaining': -1} + + plan = await UsageLimiter.get_user_plan(user_id) + max_workspaces = plan.max_active_workspaces + + if max_workspaces == -1: + return {'allowed': True, 'current': 0, 'limit': -1, 'remaining': -1} + + active_count = await UsageLimiter.get_active_workspace_count(user_id) + + if active_count >= max_workspaces: + return { + 'allowed': False, + 'current': active_count, + 'limit': max_workspaces, + 'remaining': 0, + } + + return { + 'allowed': True, + 'current': active_count, + 'limit': max_workspaces, + 'remaining': max(0, max_workspaces - active_count), + } + + @staticmethod + async def get_active_workspace_count(user_id: str) -> int: + """Count active workspaces for a user (creating or running status).""" + try: + from src.server.database.conversation import get_db_connection + from psycopg.rows import dict_row + async with get_db_connection() as conn: + async with conn.cursor(row_factory=dict_row) as cur: + await cur.execute( + "SELECT COUNT(*) as cnt FROM workspaces WHERE user_id = %s AND status IN ('creating', 'running')", + (user_id,), + ) + result = await cur.fetchone() + return result['cnt'] if result else 0 + except Exception as e: + logger.warning(f"[usage_limiter] Failed to count workspaces for {user_id}: {e}") + return 0 From c5089682c6bc1337b11678f5115b2eddb73f4ca8 Mon Sep 17 00:00:00 2001 From: Alan Chen Date: Wed, 11 Feb 2026 09:42:05 -0500 Subject: [PATCH 2/6] feat(web): add plan display, usage status, and code redemption UI --- web/src/api/client.js | 14 ++ .../pages/ChatAgent/components/ChatView.jsx | 8 + .../pages/ChatAgent/hooks/useChatMessages.js | 32 ++-- web/src/pages/ChatAgent/utils/api.js | 10 ++ .../Dashboard/components/UserConfigPanel.jsx | 138 +++++++++++++++++- web/src/pages/Dashboard/utils/api.js | 15 ++ 6 files changed, 205 insertions(+), 12 deletions(-) 
diff --git a/web/src/api/client.js b/web/src/api/client.js index 2b63b595..2c77629f 100644 --- a/web/src/api/client.js +++ b/web/src/api/client.js @@ -31,3 +31,17 @@ api.interceptors.request.use(async (config) => { } return config; }); + +// Enrich 429 errors with structured rate limit info +api.interceptors.response.use( + (response) => response, + (error) => { + if (error.response?.status === 429) { + const detail = error.response.data?.detail || {}; + error.status = 429; + error.rateLimitInfo = typeof detail === 'object' ? detail : {}; + error.retryAfter = parseInt(error.response.headers?.['retry-after'], 10) || null; + } + return Promise.reject(error); + }, +); diff --git a/web/src/pages/ChatAgent/components/ChatView.jsx b/web/src/pages/ChatAgent/components/ChatView.jsx index 3564c016..f2435a34 100644 --- a/web/src/pages/ChatAgent/components/ChatView.jsx +++ b/web/src/pages/ChatAgent/components/ChatView.jsx @@ -777,6 +777,14 @@ function ChatView({ workspaceId, threadId, onBack }) { Agent interrupted. Feel free to provide new instructions. )} + {messageError && !isLoading && ( +
+ {messageError} +
+ )} - updateMessage(prev, assistantMessageId, (msg) => ({ - ...msg, - content: msg.content || 'Failed to send message. Please try again.', - isStreaming: false, - error: true, - })) - ); + // Handle rate limit (429) — show limit message and remove optimistic assistant message + if (err.status === 429) { + const info = err.rateLimitInfo || {}; + const limitMsg = info.type === 'credit_limit' + ? `Daily credit limit reached (${info.used_credits}/${info.credit_limit} credits). Resets at midnight UTC.` + : info.type === 'workspace_limit' + ? `Active workspace limit reached (${info.current}/${info.limit}).` + : info.message || 'Rate limit exceeded. Please try again later.'; + setMessageError(limitMsg); + setMessages((prev) => prev.filter((m) => m.id !== assistantMessageId)); + } else { + console.error('Error sending message:', err); + setMessageError(err.message || 'Failed to send message'); + setMessages((prev) => + updateMessage(prev, assistantMessageId, (msg) => ({ + ...msg, + content: msg.content || 'Failed to send message. 
Please try again.', + isStreaming: false, + error: true, + })) + ); + } } finally { setIsLoading(false); currentMessageRef.current = null; diff --git a/web/src/pages/ChatAgent/utils/api.js b/web/src/pages/ChatAgent/utils/api.js index 01fe8057..6c7e03da 100644 --- a/web/src/pages/ChatAgent/utils/api.js +++ b/web/src/pages/ChatAgent/utils/api.js @@ -104,6 +104,16 @@ export async function updateThreadTitle(threadId, title) { async function streamFetch(url, opts, onEvent) { const res = await fetch(`${baseURL}${url}`, opts); if (!res.ok) { + // Handle 429 (rate limit) with structured detail + if (res.status === 429) { + let detail = {}; + try { detail = await res.json(); } catch { /* ignore */ } + const err = new Error(detail?.detail?.message || 'Rate limit exceeded'); + err.status = 429; + err.rateLimitInfo = detail?.detail || {}; + err.retryAfter = parseInt(res.headers.get('Retry-After'), 10) || null; + throw err; + } // Handle 404 specifically for history replay (expected for new threads) if (res.status === 404 && url.includes('/replay')) { throw new Error(`HTTP error! 
status: ${res.status}`); diff --git a/web/src/pages/Dashboard/components/UserConfigPanel.jsx b/web/src/pages/Dashboard/components/UserConfigPanel.jsx index 1076d1af..865e9722 100644 --- a/web/src/pages/Dashboard/components/UserConfigPanel.jsx +++ b/web/src/pages/Dashboard/components/UserConfigPanel.jsx @@ -1,7 +1,7 @@ import React, { useState, useEffect, useRef } from 'react'; import { X, User, LogOut } from 'lucide-react'; import { Input } from '../../../components/ui/input'; -import { updateCurrentUser, getCurrentUser, updatePreferences, getPreferences, uploadAvatar } from '../utils/api'; +import { updateCurrentUser, getCurrentUser, updatePreferences, getPreferences, uploadAvatar, redeemCode, getUsageStatus } from '../utils/api'; import { useAuth } from '../../../contexts/AuthContext'; import ConfirmDialog from './ConfirmDialog'; @@ -30,6 +30,14 @@ function UserConfigPanel({ isOpen, onClose }) { const [analysisFocus, setAnalysisFocus] = useState(''); const [outputStyle, setOutputStyle] = useState(''); + const [plan, setPlan] = useState({ id: 1, name: 'free', display_name: 'Free', rank: 0 }); + const [redeemInput, setRedeemInput] = useState(''); + const [isRedeeming, setIsRedeeming] = useState(false); + const [redeemError, setRedeemError] = useState(null); + const [redeemSuccess, setRedeemSuccess] = useState(null); + + const [usage, setUsage] = useState(null); + const [isSubmitting, setIsSubmitting] = useState(false); const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(null); @@ -78,7 +86,7 @@ function UserConfigPanel({ isOpen, onClose }) { useEffect(() => { if (isOpen) { setIsLoading(true); - Promise.all([loadUserData(), loadPreferencesData()]) + Promise.all([loadUserData(), loadPreferencesData(), loadUsageData()]) .finally(() => setIsLoading(false)); } }, [isOpen]); @@ -90,6 +98,7 @@ function UserConfigPanel({ isOpen, onClose }) { setName(userData.user.name || ''); setTimezone(userData.user.timezone || ''); 
setLocale(userData.user.locale || ''); + setPlan(userData.user.plan || { id: 1, name: 'free', display_name: 'Free', rank: 0 }); const url = userData.user.avatar_url; const version = userData.user.updated_at; setAvatarUrl(url ? `${url}?v=${version}` : null); @@ -112,6 +121,15 @@ function UserConfigPanel({ isOpen, onClose }) { } catch {} }; + const loadUsageData = async () => { + try { + const data = await getUsageStatus(); + setUsage(data); + } catch { + // Usage data load failed - keep null + } + }; + const handleAvatarChange = async (e) => { const file = e.target.files[0]; if (!file) return; @@ -180,8 +198,36 @@ function UserConfigPanel({ isOpen, onClose }) { onClose(); }; + const handleRedeemCode = async () => { + if (!redeemInput.trim()) return; + setIsRedeeming(true); + setRedeemError(null); + setRedeemSuccess(null); + try { + const result = await redeemCode(redeemInput.trim()); + setRedeemSuccess(result.message); + setRedeemInput(''); + refreshUser(); + await Promise.all([loadUserData(), loadUsageData()]); + } catch (err) { + const detail = err.response?.data?.detail || err.message || 'Failed to redeem code'; + setRedeemError(typeof detail === 'string' ? detail : detail.message || 'Failed to redeem code'); + } finally { + setIsRedeeming(false); + } + }; + + const PLAN_BADGE_COLORS = [ + { backgroundColor: 'var(--color-bg-card)', color: 'var(--color-text-tertiary)', border: '1px solid var(--color-border-muted)' }, + { backgroundColor: 'rgba(59, 130, 246, 0.15)', color: '#3b82f6', border: '1px solid rgba(59, 130, 246, 0.3)' }, + { backgroundColor: 'rgba(234, 179, 8, 0.15)', color: '#eab308', border: '1px solid rgba(234, 179, 8, 0.3)' }, + ]; + const getPlanBadgeStyle = (rank) => PLAN_BADGE_COLORS[Math.min(rank, PLAN_BADGE_COLORS.length - 1)]; + const handleClose = () => { setError(null); + setRedeemError(null); + setRedeemSuccess(null); onClose(); }; @@ -349,6 +395,94 @@ function UserConfigPanel({ isOpen, onClose }) { +
+ +
+ + {plan.display_name || plan.name || 'Free'} + +
+ + {usage && ( +
+ {/* Credits */} +
+
+ Daily Credits + + {usage.credits.limit === -1 + ? 'Unlimited' + : `${usage.credits.used} / ${usage.credits.limit}`} + +
+ {usage.credits.limit !== -1 && ( +
+
0.9 + ? 'var(--color-loss)' + : 'var(--color-accent-primary)', + }} + /> +
+ )} +
+ + {/* Workspaces */} +
+
+ Active Workspaces + + {usage.workspaces.limit === -1 + ? 'Unlimited' + : `${usage.workspaces.active} / ${usage.workspaces.limit}`} + +
+
+
+ )} + +
+ { setRedeemInput(e.target.value); setRedeemError(null); setRedeemSuccess(null); }} + placeholder="Enter redemption code" + className="flex-1 rounded-md px-3 py-1.5 text-sm" + style={{ + backgroundColor: 'var(--color-bg-card)', + border: '1px solid var(--color-border-muted)', + color: 'var(--color-text-primary)', + }} + disabled={isRedeeming} + onKeyDown={(e) => { if (e.key === 'Enter') { e.preventDefault(); handleRedeemCode(); } }} + /> + +
+ {redeemError && ( +

{redeemError}

+ )} + {redeemSuccess && ( +

{redeemSuccess}

+ )} +
+ {error && (

{error}

diff --git a/web/src/pages/Dashboard/utils/api.js b/web/src/pages/Dashboard/utils/api.js index 0657fb76..de662529 100644 --- a/web/src/pages/Dashboard/utils/api.js +++ b/web/src/pages/Dashboard/utils/api.js @@ -167,6 +167,21 @@ export async function updatePreferences(preferences) { return data; } +export async function getUsageStatus() { + const { data } = await api.get('/api/v1/usage'); + return data; +} + +export async function redeemCode(code) { + const { data } = await api.post('/api/v1/usage/redeem', { code }); + return data; +} + +export async function getPlans() { + const { data } = await api.get('/api/v1/plans'); + return data; +} + export async function uploadAvatar(file) { const formData = new FormData(); formData.append('file', file); From c36c571409931639e4b3c7a3ee8dd499aa75d0ea Mon Sep 17 00:00:00 2001 From: Alan Chen Date: Wed, 11 Feb 2026 09:42:11 -0500 Subject: [PATCH 3/6] chore: remove stale frontend submodule and changelog --- frontend | 1 - web/FRONTEND_SSE_ADAPTATION_CHANGELOG.md | 116 ----------------------- 2 files changed, 117 deletions(-) delete mode 160000 frontend delete mode 100644 web/FRONTEND_SSE_ADAPTATION_CHANGELOG.md diff --git a/frontend b/frontend deleted file mode 160000 index e36726c8..00000000 --- a/frontend +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e36726c8525eb7a8c3729e37239870a0ffb16e6a diff --git a/web/FRONTEND_SSE_ADAPTATION_CHANGELOG.md b/web/FRONTEND_SSE_ADAPTATION_CHANGELOG.md deleted file mode 100644 index 8914e6ba..00000000 --- a/web/FRONTEND_SSE_ADAPTATION_CHANGELOG.md +++ /dev/null @@ -1,116 +0,0 @@ -# Frontend SSE Adaptation Changelog - -This document summarizes the frontend changes made to adapt to the backend chat SSE event stream updates. - -## 1. Tool Name Changes (snake_case → PascalCase) - -**File:** `src/pages/ChatAgent/components/ToolCallMessageContent.jsx` - -- **Change:** Updated `FILE_TOOLS` constant to include both PascalCase (new) and snake_case (legacy) tool names for backward compatibility. 
-- **New tool names:** `Read`, `Write`, `Edit`, `Save`, `ExecuteCode`, `Glob`, `Grep`, `WebFetch`, `WebSearch` -- **Legacy names retained:** `read_file`, `write_file`, `edit_file`, `save_file` (for older history) -- **Why:** Backend now uses LangChain SDK convention (PascalCase). FILE_TOOLS is used to detect file-related tools for opening in the file panel. - ---- - -## 2. Subagent Event Detection - -**Files:** `hooks/utils/streamEventHandlers.js`, `hooks/utils/historyEventHandlers.js` - -### `isSubagentEvent` / `isSubagentHistoryEvent` - -- **Old logic:** `event.agent.startsWith('tools:')` -- **New logic:** Subagent if `agent` contains `:` AND does NOT start with `model:` AND is NOT `"tools"`. -- **Rationale:** Backend convention: - - Main agent: `agent.startsWith("model:")` - - Tool node: `agent === "tools"` - - Subagent: `agent` = `"{type}:{uuid4}"` (e.g., `"research:550e8400-..."`) - ---- - -## 3. Subagent Status Event Handling - -**File:** `hooks/utils/streamEventHandlers.js` – `handleSubagentStatus` - -### Preferred Format (from BackgroundTaskRegistry) - -- `active_tasks`: Array of objects with `id` (display_id), `agent_id` (stable UUID), `description`, `type`, `tool_calls`, `current_tool` -- `completed_tasks`: Array of display_id strings (`"Task-1"`, `"Task-2"`) - -### Fallback Format (legacy) - -- `active_subagents` / `completed_subagents`: Arrays of `agent_id` strings - -### Key Changes - -- Uses `agent_id` as the primary key for cards (not `id`/display_id). -- Passes `displayIdToAgentIdMap` to resolve `completed_tasks` display IDs to `agent_id`s. -- Calls `updateSubagentCard(agentId, {...})` with `agentId` as the card key. -- Stores `displayId` for human-readable UI when available. - ---- - -## 4. Agent ID as Primary Identifier - -**Files:** `hooks/useChatMessages.js`, `hooks/useFloatingCards.js`, `components/ChatView.jsx` - -### Refactoring - -- **`agent_id`** (format `{type}:{uuid4}`) is the stable identifier for subagent cards and event routing. 
-- **`display_id`** (`"Task-1"`, `"Task-2"`) is used only for UI display. -- Card IDs: `subagent-${agentId}` (e.g., `subagent-research:550e8400-...`). - -### Mappings - -- `agentToTaskMapRef`: Maps `agent_id` → `agent_id`. -- `toolCallIdToTaskIdMapRef`: Maps tool call IDs (from main agent’s `task` tool) → `agent_id`. -- `displayIdToAgentIdMapRef`: Maps display_id → `agent_id` for resolving `completed_tasks`. - -### New Helpers - -- `resolveSubagentIdToAgentId(subagentId)`: Resolves tool call ID or legacy ID to stable `agent_id`. -- `getSubagentHistory(subagentId)`: Returns history including `agentId` for card operations. - ---- - -## 5. History Loading - -**File:** `hooks/useChatMessages.js` - -### Subagent Status in History - -- Handles both preferred and fallback formats. -- Uses `task.agent_id` or `task.agent` for identity. -- Stores subagent history keyed by `agent_id`. - -### Order-Based Matching - -- `historyPendingAgentIdsRef`: Holds `agent_id`s from `subagent_status`. -- `historyPendingTaskToolCallIdsRef`: Holds tool call IDs from `task` tool calls. -- Order-based matching builds `toolCallIdToTaskIdMapRef` so segments and history resolve to the correct `agent_id`. - ---- - -## 6. ChatView and Floating Cards - -**File:** `components/ChatView.jsx` - -- `onOpenSubagentTask`: Uses `resolveSubagentIdToAgentId` to convert segment `subagentId` (possibly a tool call ID) to `agent_id` before calling `updateSubagentCard` and `setSelectedAgentId`. -- Agent panel: Uses `displayId` when available for tab labels (e.g., `"Task-1"`). - -**File:** `hooks/useFloatingCards.js` - -- `updateSubagentCard(agentId, ...)`: First parameter is `agent_id`, not display_id. -- Stores `agentId` and `displayId` in `subagentData` for UI and routing. 
- ---- - -## Summary - -| Area | Old | New | -|------|-----|-----| -| Tool names | snake_case | PascalCase (with legacy fallback) | -| Subagent detection | `agent.startsWith('tools:')` | `agent` contains `:` and not `model:` and not `tools` | -| Subagent identity | display_id (`"Task-1"`) | `agent_id` (`"type:uuid4"`) | -| Card key | task ID / display ID | `agent_id` | -| `subagent_status` | `active_subagents` / `completed_subagents` | `active_tasks` / `completed_tasks` (preferred) with fallback | From 93da34e79ba1541f12d6c253935d4c2778d5ca13 Mon Sep 17 00:00:00 2001 From: Alan Chen Date: Wed, 11 Feb 2026 09:42:17 -0500 Subject: [PATCH 4/6] fix(server): deduplicate utils __init__ exports --- src/server/utils/__init__.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/server/utils/__init__.py b/src/server/utils/__init__.py index 8c11fca8..ea47d53b 100644 --- a/src/server/utils/__init__.py +++ b/src/server/utils/__init__.py @@ -1,11 +1,6 @@ """Server utility functions.""" from .api import CurrentUserId, handle_api_exceptions, raise_not_found -from .checkpoint_helpers import ( - build_checkpoint_config, - get_checkpointer, - require_checkpointer, -) from .checkpointer import ( close_checkpointer_pool, get_checkpointer, @@ -17,13 +12,10 @@ __all__ = [ "CurrentUserId", "UpdateQueryBuilder", - "build_checkpoint_config", "close_checkpointer_pool", "deduplicate_agent_messages", "get_checkpointer", - "get_checkpointer", "handle_api_exceptions", "open_checkpointer_pool", "raise_not_found", - "require_checkpointer", ] From a44f3403b4fadae1113d4ad15b67a646bd10d843 Mon Sep 17 00:00:00 2001 From: Alan Chen Date: Wed, 11 Feb 2026 12:12:31 -0500 Subject: [PATCH 5/6] feat(server): add BYOK API keys and per-user model preferences --- scripts/migrate.py | 28 +- scripts/migrations/003_user_api_keys.sql | 19 + src/llms/llm.py | 184 +++++----- src/llms/manifest/models.json | 441 +---------------------- src/llms/manifest/providers.json | 432 ++-------------------- 
src/server/app/api_keys.py | 164 +++++++++ src/server/app/chat.py | 127 ++++++- src/server/app/setup.py | 2 + src/server/app/usage.py | 6 + src/server/app/users.py | 13 +- src/server/database/api_keys.py | 171 +++++++++ src/server/dependencies/usage_limits.py | 44 ++- 12 files changed, 672 insertions(+), 959 deletions(-) create mode 100644 scripts/migrations/003_user_api_keys.sql create mode 100644 src/server/app/api_keys.py create mode 100644 src/server/database/api_keys.py diff --git a/scripts/migrate.py b/scripts/migrate.py index 0dad72fa..69606ad2 100644 --- a/scripts/migrate.py +++ b/scripts/migrate.py @@ -100,9 +100,31 @@ async def run_migrations(): sql = migration_file.read_text() try: - # Split and execute statements separately - # (psycopg3 doesn't support multiple statements in one execute) - statements = [s.strip() for s in sql.split(';') if s.strip() and not s.strip().startswith('--')] + # Split SQL into individual statements, respecting + # parenthesized blocks (e.g. CREATE TABLE (...;)). + # Only split on ';' at top-level (depth == 0). 
+ statements = [] + buf = [] + depth = 0 + for line in sql.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith('--'): + continue + depth += stripped.count('(') - stripped.count(')') + if stripped.endswith(';') and depth <= 0: + buf.append(stripped[:-1]) # drop trailing ; + stmt = ' '.join(buf).strip() + if stmt: + statements.append(stmt) + buf = [] + depth = 0 + else: + buf.append(stripped) + # Catch trailing statement without semicolon + trailing = ' '.join(buf).strip() + if trailing: + statements.append(trailing) + for stmt in statements: await cur.execute(stmt) await cur.execute( diff --git a/scripts/migrations/003_user_api_keys.sql b/scripts/migrations/003_user_api_keys.sql new file mode 100644 index 00000000..ae3d661b --- /dev/null +++ b/scripts/migrations/003_user_api_keys.sql @@ -0,0 +1,19 @@ +-- Migration 003: User API keys for BYOK (Bring Your Own Key) support +-- Purpose: Allow users to provide their own LLM API keys to bypass credit limits +-- Requires: pgcrypto extension for symmetric encryption of API keys at rest + +CREATE EXTENSION IF NOT EXISTS pgcrypto; + +-- 1. Per-provider API keys (one row per user+provider) +CREATE TABLE IF NOT EXISTS user_api_keys ( + user_id VARCHAR(255) + REFERENCES users(user_id) ON DELETE CASCADE ON UPDATE CASCADE, + provider VARCHAR(50) NOT NULL, + api_key BYTEA NOT NULL, -- encrypted via pgp_sym_encrypt + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW(), + PRIMARY KEY (user_id, provider) +); + +-- 2. 
BYOK toggle on users table (global per-user switch) +ALTER TABLE users ADD COLUMN IF NOT EXISTS byok_enabled BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/src/llms/llm.py b/src/llms/llm.py index e6092a96..3ddf6dab 100644 --- a/src/llms/llm.py +++ b/src/llms/llm.py @@ -13,7 +13,7 @@ class ModelConfig: """Manages model configuration from JSON files.""" - + def __init__(self): # Load models.json for model parameters llm_config_path = Path(__file__).parent / "manifest" / "models.json" @@ -24,15 +24,15 @@ def __init__(self): manifest_path = Path(__file__).parent / "manifest" / "providers.json" with open(manifest_path, 'r') as f: self.manifest = json.load(f) - + def get_model_config(self, model_id: str) -> Optional[Dict]: """Get model configuration from llm_config.""" return self.llm_config.get(model_id) - + def get_provider_info(self, provider: str) -> Dict: """Get provider configuration from manifest.""" return self.manifest["provider_config"].get(provider, {}) - + def get_model_pricing(self, custom_model_name: str) -> Optional[Dict[str, Any]]: """Get pricing information for a specific model from manifest.""" # Get model info from llm_config first @@ -66,43 +66,55 @@ def get_model_info(self, provider: str, model_id: str) -> Optional[Dict[str, Any return model return None + def get_byok_eligible_providers(self) -> list[str]: + """Return list of provider names that have byok_eligible=true in manifest.""" + return [ + name + for name, cfg in self.manifest.get("provider_config", {}).items() + if cfg.get("byok_eligible", False) + ] + class LLM: """Factory class for creating LangChain LLM clients.""" - + # Class-level model config instance _model_config = None - + @classmethod def get_model_config(cls) -> ModelConfig: """Get or create the model configuration singleton.""" if cls._model_config is None: cls._model_config = ModelConfig() return cls._model_config - - def __init__(self, model: str, **override_params): + + def __init__(self, model: str, api_key: str | None = None, 
**override_params): """ Initializes the LLM factory. Args: model: The customized model name (key in llm_config.json). + api_key: Optional API key override (e.g. from BYOK). **override_params: Additional parameters to override defaults. """ self.model_config = self.get_model_config() - + # Get model configuration from models.json model_info = self.model_config.get_model_config(model) if not model_info: raise ValueError(f"Model {model} not found in models.json") - + self.custom_model_name = model # Store the custom name self.model = model_info["model_id"] # Use model_id for API calls self.provider = model_info["provider"] self.parameters = model_info.get("parameters", {}).copy() - + # Override with any provided parameters self.parameters.update(override_params) + # Store optional API key override (BYOK) + self.api_key_override = api_key + # Get provider info from manifest self.provider_info = self.model_config.get_provider_info(self.provider) @@ -137,36 +149,41 @@ def get_llm(self): return self._get_gemini_llm() else: raise ValueError(f"Unsupported SDK: {self.sdk} for provider {self.provider}") - + + def _resolve_api_key(self) -> str: + """Resolve API key: BYOK override > env var > local fallback.""" + if self.api_key_override: + return self.api_key_override + if self.env_key: + key = os.getenv(self.env_key) + if not key: + raise ValueError(f"{self.env_key} environment variable is not set") + return key + return "lm-studio" if self.provider == "lm-studio" else "EMPTY" + + def _resolve_base_url(self, param_name: str = "base_url") -> dict: + """Resolve base URL with HOST_IP substitution. 
Returns dict to merge into params.""" + if not self.base_url: + return {} + url = self.base_url + if "{HOST_IP}" in url: + host_ip = os.getenv("HOST_IP") + if not host_ip: + raise ValueError(f"HOST_IP environment variable is not set for {self.provider}") + url = url.replace("{HOST_IP}", host_ip) + return {param_name: url} + def _get_openai_llm(self): """Get OpenAI or OpenAI-compatible LLM.""" params = { "model": self.model, + "api_key": self._resolve_api_key(), "stream_usage": True, "max_retries": 5, - "timeout": 600.0, # 10 minutes - sufficient for long reasoning + "timeout": 600.0, } + params.update(self._resolve_base_url("base_url")) - # Set API key from provider configuration - if self.env_key: - params["api_key"] = os.getenv(self.env_key) - if not params["api_key"]: - raise ValueError(f"{self.env_key} environment variable is not set") - else: - # Special case for local providers without API key - params["api_key"] = "lm-studio" if self.provider == "lm-studio" else "EMPTY" - - # Set base URL from provider configuration - if self.base_url: - # Handle HOST_IP replacement for local providers - if "{HOST_IP}" in self.base_url: - host_ip = os.getenv("HOST_IP") - if not host_ip: - raise ValueError(f"HOST_IP environment variable is not set for {self.provider}") - params["base_url"] = self.base_url.replace("{HOST_IP}", host_ip) - else: - params["base_url"] = self.base_url - # Handle Response API if configured if self.use_response_api: params["output_version"] = "responses/v1" @@ -177,38 +194,19 @@ def _get_openai_llm(self): # Add all parameters from llm_config params.update(self.parameters) - + return ChatOpenAI(**params) def _get_deepseek_llm(self): """Get DeepSeek or DeepSeek-compatible LLM.""" params = { "model": self.model, + "api_key": self._resolve_api_key(), "stream_usage": True, "max_retries": 5, - "timeout": 600.0, # 10 minutes - sufficient for long reasoning + "timeout": 600.0, } - - # Set API key from provider configuration - if self.env_key: - 
params["api_key"] = os.getenv(self.env_key) - if not params["api_key"]: - raise ValueError(f"{self.env_key} environment variable is not set") - else: - # Special case for local providers without API key - params["api_key"] = "EMPTY" - - # Set base URL from provider configuration (ChatDeepSeek uses api_base) - if self.base_url: - # Handle HOST_IP replacement for local providers - if "{HOST_IP}" in self.base_url: - host_ip = os.getenv("HOST_IP") - if not host_ip: - raise ValueError(f"HOST_IP environment variable is not set for {self.provider}") - params["api_base"] = self.base_url.replace("{HOST_IP}", host_ip) - else: - params["api_base"] = self.base_url - + params.update(self._resolve_base_url("api_base")) # Add all parameters from llm_config params.update(self.parameters) @@ -219,31 +217,12 @@ def _get_qwq_llm(self): """Get QwQ or QwQ-compatible LLM (for Qwen models with reasoning support).""" params = { "model": self.model, + "api_key": self._resolve_api_key(), "stream_usage": True, "max_retries": 5, - "timeout": 600.0, # 10 minutes - sufficient for long reasoning + "timeout": 600.0, } - - # Set API key from provider configuration - if self.env_key: - params["api_key"] = os.getenv(self.env_key) - if not params["api_key"]: - raise ValueError(f"{self.env_key} environment variable is not set") - else: - # Special case for local providers without API key - params["api_key"] = "EMPTY" - - # Set base URL from provider configuration (ChatQwQ uses api_base) - if self.base_url: - # Handle HOST_IP replacement for local providers - if "{HOST_IP}" in self.base_url: - host_ip = os.getenv("HOST_IP") - if not host_ip: - raise ValueError(f"HOST_IP environment variable is not set for {self.provider}") - params["api_base"] = self.base_url.replace("{HOST_IP}", host_ip) - else: - params["api_base"] = self.base_url - + params.update(self._resolve_base_url("api_base")) # Add all parameters from llm_config params.update(self.parameters) @@ -254,9 +233,12 @@ def _get_anthropic_llm(self): 
"""Get Anthropic LLM.""" from langchain_anthropic import ChatAnthropic + # Set API key: prefer BYOK override, then env var + api_key = self.api_key_override or (os.getenv(self.env_key) if self.env_key else None) + params = { "model": self.model, - "api_key": os.getenv(self.env_key) if self.env_key else None, + "api_key": api_key, "max_retries": 5, "timeout": 600.0, # 10 minutes - sufficient for long reasoning } @@ -275,39 +257,44 @@ def _get_anthropic_llm(self): params.update(filtered_params) return ChatAnthropic(**params) - + def _get_gemini_llm(self): """Get Gemini LLM.""" from langchain_google_genai import ChatGoogleGenerativeAI + + # Set API key: prefer BYOK override, then env var + api_key = self.api_key_override or (os.getenv(self.env_key) if self.env_key else None) + params = { "model": self.model, - "api_key": os.getenv(self.env_key) if self.env_key else None, + "api_key": api_key, "timeout": 600.0, # 10 minutes - sufficient for long reasoning } if not params["api_key"]: raise ValueError(f"{self.env_key or 'GEMINI_API_KEY'} environment variable is not set") - + # Add all parameters from llm_config params.update(self.parameters) - + return ChatGoogleGenerativeAI(**params) # Backward compatibility functions -def create_llm(model: str, **kwargs): +def create_llm(model: str, api_key: str | None = None, **kwargs): """ Convenience function for creating an LLM instance. - + Args: model: The model name + api_key: Optional API key override (e.g. from BYOK) **kwargs: Additional parameters to override - + Returns: A LangChain chat model instance """ - return LLM(model, **kwargs).get_llm() + return LLM(model, api_key=api_key, **kwargs).get_llm() def get_llm_by_type(llm_type: str) -> BaseChatModel: @@ -330,19 +317,19 @@ def get_llm_by_type(llm_type: str) -> BaseChatModel: def get_configured_llm_models() -> dict[str, list[str]]: """ - Get all configured LLM models grouped by provider. + Get visible LLM models grouped by provider. 
+ + Only returns models with "visible": true in models.json. Returns: - Dictionary mapping provider to list of configured model names. + Dictionary mapping provider to list of visible model names. """ try: config = ModelConfig() models: dict[str, list[str]] = {} - # Group all models by provider - for model_name in config.llm_config.keys(): - model_info = config.get_model_config(model_name) - if model_info: + for model_name, model_info in config.llm_config.items(): + if model_info and model_info.get("visible", False): provider = model_info.get("provider", "unknown") models.setdefault(provider, []).append(model_name) @@ -352,7 +339,7 @@ def get_configured_llm_models() -> dict[str, list[str]]: # Log error and return empty dict to avoid breaking the application print(f"Warning: Failed to load LLM configuration: {e}") return {} - + def should_enable_caching(model_name: str) -> bool: """ Check if a model should enable Anthropic prompt caching. @@ -374,12 +361,3 @@ def should_enable_caching(model_name: str) -> bool: return parameters.get("enable_caching", False) except Exception: return False - - -## Important Note: -# 1. The models.json file (src/llms/manifest/models.json) is used to store the detailed model configuration and name mapping. -# 2. The providers.json file (src/llms/manifest/providers.json) is used: -# - to store the model pricing information and model parameters -# - to store the model parameters from the model providers. -# - to store the providers information including the SDK, base URL, environment key. -# We assume all the configurations in models.json are valid and complete - always validate the configurations when adding new models. 
\ No newline at end of file diff --git a/src/llms/manifest/models.json b/src/llms/manifest/models.json index 5393a2d8..3935dcbd 100644 --- a/src/llms/manifest/models.json +++ b/src/llms/manifest/models.json @@ -1,34 +1,4 @@ { - "gpt-5": { - "model_id": "gpt-5", - "provider": "openai", - "parameters": { - "reasoning": { - "effort": "minimal", - "summary": "auto" - } - } - }, - "gpt-5-high": { - "model_id": "gpt-5", - "provider": "openai", - "parameters": { - "reasoning": { - "effort": "high", - "summary": "auto" - } - } - }, - "gpt-5-medium": { - "model_id": "gpt-5", - "provider": "openai", - "parameters": { - "reasoning": { - "effort": "medium", - "summary": "auto" - } - } - }, "gpt-5-mini": { "model_id": "gpt-5-mini", "provider": "openai", @@ -49,13 +19,6 @@ } } }, - "gpt-5-mini-medium-open": { - "model_id": "gpt-5-mini", - "provider": "openrouter", - "parameters": { - "reasoning_effort": "medium" - } - }, "gpt-5-nano": { "model_id": "gpt-5-nano", "provider": "openai", @@ -66,13 +29,6 @@ } } }, - "gpt-5-nano-open": { - "model_id": "gpt-5-nano", - "provider": "openrouter", - "parameters": { - "reasoning_effort": "minimal" - } - }, "gpt-5-nano-medium": { "model_id": "gpt-5-nano", "provider": "openai", @@ -83,16 +39,6 @@ } } }, - "gpt-5.1-codex": { - "model_id": "gpt-5.1-codex", - "provider": "openai", - "parameters": { - "reasoning": { - "effort": "minimal", - "summary": "auto" - } - } - }, "gpt-5.1-codex-mini-low": { "model_id": "gpt-5.1-codex-mini", "provider": "openai", @@ -143,85 +89,13 @@ } } }, - "gpt-4.1": { - "model_id": "gpt-4.1", - "provider": "openai", - "parameters": { - "temperature": 0, - "top_p": 1, - "presence_penalty": 0, - "frequency_penalty": 0 - } - }, - "gpt-4.1-mini": { - "model_id": "gpt-4.1-mini", - "provider": "openai", - "parameters": { - "temperature": 0, - "top_p": 1 - } - }, - "gpt-4.1-nano": { - "model_id": "gpt-4.1-nano", - "provider": "openai", - "parameters": { - "temperature": 0 - } - }, - "gpt-4o": { - "model_id": "gpt-4o", - 
"provider": "openai", - "parameters": { - "temperature": 0, - "top_p": 1 - } - }, - "gpt-4o-mini": { - "model_id": "gpt-4o-mini", - "provider": "openai", - "parameters": { - "temperature": 0 - } - }, - "o3": { - "model_id": "o3", - "provider": "openai", - "parameters": { - "reasoning": { - "effort": "low", - "summary": "auto" - }, - "max_tokens": 16384 - } - }, - "o4-mini-open": { - "model_id": "o4-mini", - "provider": "openrouter", - "parameters": { - "reasoning_effort": "low" - } - }, - "claude-opus-4-5-20251101": { - "model_id": "claude-opus-4-5-20251101", - "provider": "anthropic", - "parameters": { - "temperature": 0, - "max_tokens": 4096, - "top_p": 1, - "thinking": { - "type": "disabled" - }, - "enable_caching": true - } - }, - "claude-opus-4": { - "model_id": "claude-opus-4-20250522", + "claude-opus-4-6": { + "model_id": "claude-opus-4-6", "provider": "anthropic", "parameters": { "thinking": { - "type": "disabled" - }, - "enable_caching": true + "type": "adaptive" + } } }, "claude-sonnet-4-5": { @@ -236,22 +110,10 @@ "enable_caching": true } }, - "claude-sonnet-4": { - "model_id": "claude-sonnet-4-20250522", - "provider": "anthropic", - "parameters": { - "temperature": 0, - "max_tokens": 4096, - "top_p": 1, - "thinking": { - "type": "disabled" - }, - "enable_caching": true - } - }, "gemini-2.5-pro": { "model_id": "gemini-2.5-pro", "provider": "gemini", + "visible": true, "parameters": { "thinking_budget": 0, "temperature": 0, @@ -262,6 +124,7 @@ "gemini-2.5-flash": { "model_id": "gemini-2.5-flash", "provider": "gemini", + "visible": true, "parameters": { "thinking_budget": 0, "temperature": 0 @@ -270,6 +133,7 @@ "gemini-2.5-flash-lite": { "model_id": "gemini-2.5-flash-lite", "provider": "gemini", + "visible": true, "parameters": { "thinking_budget": 0, "temperature": 0, @@ -280,22 +144,25 @@ "gemini-3-pro": { "model_id": "gemini-3-pro-preview", "provider": "gemini", + "visible": true, "parameters": { - "include_thoughts": true, + "include_thoughts": true, 
"thinking_level": "high" } }, "gemini-3-flash": { "model_id": "gemini-3-flash-preview", "provider": "gemini", + "visible": true, "parameters": { - "include_thoughts": true, + "include_thoughts": true, "thinking_level": "medium" } }, "gemini-3-pro-image": { "model_id": "gemini-3-pro-image-preview", - "provider": "gemini" + "provider": "gemini", + "visible": true }, "gpt-oss-120b": { "model_id": "gpt-oss-120b", @@ -340,105 +207,13 @@ "reasoning_effort": "medium" } }, - "doubao-seed-thinking": { - "model_id": "doubao-seed-1-6-thinking-250615", - "provider": "volcengine", - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "enabled"} - } - }, - "reasoning-model": { - "model_id": "doubao-seed-1-6-thinking-250615", - "provider": "volcengine", - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "enabled"} - } - }, - "basic": { - "model_id": "doubao-seed-1.6-251015", - "provider": "volcengine", - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "disabled"} - } - }, - "reasoning": { - "model_id": "doubao-seed-1-6-thinking-250615", - "provider": "volcengine", - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "enabled"} - } - }, - "vision": { - "model_id": "doubao-seed-1-6-vision-250615", - "provider": "volcengine", - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "enabled"} - } - }, - "doubao_seed_1_6_thinking": { - "model_id": "doubao-seed-1-6-thinking-250715", - "provider": "volcengine", - "parameters": { - "store": true, - "max_tokens": 32000 - }, - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "enabled"} - } - }, - "doubao-seed-1.6-vision": { - "model_id": "doubao-seed-1-6-vision-250815", - "provider": "volcengine", - "parameters": { - "store": true, - "max_tokens": 32000 - }, - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "enabled"} - } - }, - "doubao-seed-1.6": { - "model_id": 
"doubao-seed-1-6-251015", - "provider": "volcengine", - "parameters": { - "store": true, - "max_tokens": 32000 - }, - "extra_body": { - "thinking": {"type": "enabled"}, - "reasoning_effort": "medium", - "caching": {"type": "enabled"} - } - }, "doubao-seed-1.8": { "model_id": "doubao-seed-1-8-251228", "provider": "volcengine", "parameters": { "store": true, "reasoning": { - "effort": "high" - } - }, - "extra_body": { - "thinking": {"type": "enabled"}, - "caching": {"type": "enabled" - } - } - }, - "doubao-seed-1.8-non-think": { - "model_id": "doubao-seed-1-8-251228", - "provider": "volcengine", - "parameters": { - "store": true, - "reasoning": { - "effort": "minimal" + "effort": "medium" } }, "extra_body": { @@ -446,37 +221,6 @@ "caching": {"type": "enabled"} } }, - "doubao-seed-1.6-flash": { - "model_id": "doubao-seed-1-6-flash-250615", - "provider": "volcengine", - "parameters": { - "store": true, - "max_tokens": 1500 - }, - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "disabled"} - } - }, - "doubao-seed-translation": { - "model_id": "doubao-seed-translation", - "provider": "volcengine" - }, - "kimi-k2-thinking": { - "model_id": "kimi-k2-thinking", - "provider": "moonshot" - }, - "deepseek-v3.1": { - "model_id": "deepseek-v3-1-terminus", - "provider": "volcengine", - "parameters": { - "store": true - }, - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "enabled"} - } - }, "deepseek-reasoner": { "model_id": "deepseek-reasoner", "provider": "deepseek", @@ -484,50 +228,6 @@ "temperature": 0 } }, - "tongyi-deepresearch-30b": { - "model_id": "alibaba/tongyi-deepresearch-30b-a3b", - "provider": "openrouter", - "parameters": { - "temperature": 0 - } - }, - "qwen3-coder": { - "model_id": "qwen/qwen3-coder", - "provider": "openrouter", - "parameters": { - "temperature": 0 - } - }, - "grok-4-fast": { - "model_id": "x-ai/grok-4-fast:free", - "provider": "openrouter", - "parameters": { - "temperature": 0, - "reasoning": { - 
"effort": "medium" - } - } - }, - "grok-4-fast-low": { - "model_id": "x-ai/grok-4-fast:free", - "provider": "openrouter", - "parameters": { - "temperature": 0, - "reasoning": { - "effort": "low" - } - } - }, - "grok-4-fast-high": { - "model_id": "x-ai/grok-4-fast:free", - "provider": "openrouter", - "parameters": { - "temperature": 0, - "reasoning": { - "effort": "high" - } - } - }, "glm-4.7": { "model_id": "glm-4.7", "provider": "z-ai", @@ -568,44 +268,6 @@ } } }, - "doubao-seed-code-responses": { - "model_id": "doubao-seed-code-preview-251028", - "provider": "volcengine", - "parameters": { - "temperature": 0, - "max_tokens": 32000, - "store": true - }, - "extra_body": { - "caching": {"type": "enabled"}, - "thinking": {"type": "disabled"} - } - }, - "doubao-seed-code": { - "model_id": "doubao-seed-code-preview-251028", - "provider": "doubao-anthropic", - "parameters": { - "temperature": 0, - "max_tokens": 32000, - "thinking": { - "type": "disabled" - } - } - }, - "qwen3-vl-plus": { - "model_id": "qwen3-vl-plus", - "provider": "dashscope-sg", - "parameters":{ - "enable_thinking": true - } - }, - "qwen3-vl-flash": { - "model_id": "qwen3-vl-flash", - "provider": "dashscope-sg", - "parameters":{ - "enable_thinking": true - } - }, "qwen3-max": { "model_id": "qwen3-max", "provider": "dashscope" @@ -617,34 +279,6 @@ "enable_thinking": true } }, - "qwen-plus-80b": { - "model_id": "qwen-plus-2025-09-11", - "provider": "dashscope", - "parameters":{ - "enable_thinking": true - } - }, - "qwen-plus": { - "model_id": "qwen-plus", - "provider": "dashscope", - "parameters":{ - "enable_thinking": true - } - }, - "qwen-flash": { - "model_id": "qwen-flash", - "provider": "dashscope", - "parameters":{ - "enable_thinking": true - } - }, - "qwen-flash-non-think": { - "model_id": "qwen-flash", - "provider": "dashscope", - "parameters":{ - "enable_thinking": false - } - }, "qwen3-max-anthropic": { "model_id": "qwen3-max", "provider": "dashscope-anthropic" @@ -656,20 +290,6 @@ 
"enable_thinking": true } }, - "qwen-plus-80b-anthropic": { - "model_id": "qwen-plus-2025-09-11", - "provider": "dashscope-anthropic", - "extra_body": { - "enable_thinking": true - } - }, - "qwen-flash-anthropic": { - "model_id": "qwen-flash", - "provider": "dashscope-anthropic", - "extra_body": { - "enable_thinking": true - } - }, "embedding-small": { "model_id": "text-embedding-3-small", "provider": "openai", @@ -691,37 +311,6 @@ "dimensions": 1536 } }, - "claude-sonnet-4-5-proxy": { - "model_id": "claude-sonnet-4-5-20250929", - "provider": "anthropic-proxy", - "parameters": { - "max_tokens": 16000, - "thinking": { - "type": "enabled" - }, - "enable_caching": true - } - }, - "claude-opus-4-5-proxy": { - "model_id": "claude-opus-4-5-20251101", - "provider": "anthropic-proxy", - "parameters": { - "thinking": { - "type": "enabled" - }, - "enable_caching": true - } - }, - "claude-haiku-4-5-proxy": { - "model_id": "claude-haiku-4-5-20251001", - "provider": "anthropic-proxy", - "parameters": { - "thinking": { - "type": "enabled" - }, - "enable_caching": true - } - }, "kimi": { "model_id": "kimi-for-coding", "provider": "moonshot", @@ -731,4 +320,4 @@ } } } -} \ No newline at end of file +} diff --git a/src/llms/manifest/providers.json b/src/llms/manifest/providers.json index 7851c6ca..df35dd89 100644 --- a/src/llms/manifest/providers.json +++ b/src/llms/manifest/providers.json @@ -40,19 +40,6 @@ }, "models": { "openai": [ - { - "id": "gpt-5", - "name": "GPT-5", - "is_reasoning": true, - "description": "OpenAI's most advanced reasoning model", - "alias": ["gpt-5-2025-08-07"], - "pricing": { - "input": 1.25, - "cached_input": 0.125, - "output": 10.00, - "unit": "per_1m_tokens" - } - }, { "id": "gpt-5-mini", "name": "GPT-5 Mini", @@ -128,75 +115,15 @@ "output": 14.00, "unit": "per_1m_tokens" } - }, - { - "id": "gpt-4.1", - "name": "GPT-4.1", - "is_reasoning": false, - "description": "GPT-4.1 model", - "pricing": { - "input": 2.00, - "cached_input": 0.50, - "output": 8.00, 
- "unit": "per_1m_tokens" - } - }, - { - "id": "gpt-4.1-mini", - "name": "GPT-4.1 Mini", - "is_reasoning": false, - "description": "Smaller GPT-4.1 variant", - "pricing": { - "input": 0.40, - "cached_input": 0.10, - "output": 1.60, - "unit": "per_1m_tokens" - } - }, - { - "id": "gpt-4.1-nano", - "name": "GPT-4.1 Nano", - "is_reasoning": false, - "description": "Smallest GPT-4.1 variant", - "pricing": { - "input": 0.10, - "cached_input": 0.025, - "output": 0.40, - "unit": "per_1m_tokens" - } - }, - { - "id": "gpt-4o", - "name": "GPT-4o", - "is_reasoning": false, - "description": "Optimized GPT-4 variant" - }, - { - "id": "gpt-4o-mini", - "name": "GPT-4o Mini", - "is_reasoning": false, - "description": "Smaller GPT-4o variant" - }, - { - "id": "o3", - "name": "O3", - "is_reasoning": true, - "description": "Advanced reasoning model for complex multi-faceted analysis" - }, - { - "id": "o4-mini", - "name": "O4 Mini", - "is_reasoning": true, - "description": "Fast, cost-efficient reasoning model" } ], "anthropic": [ - { - "id": "claude-opus-4-5-20251101", - "name": "Claude 4.5 Opus", - "is_reasoning": false, - "description": "Most intelligent model for building agents and coding", - "alias": ["claude-opus-4.5"], + { + "id": "claude-opus-4-6", + "name": "Claude 4.6 Opus", + "is_reasoning": true, + "description": "Claude 4.6 Opus model", + "alias": ["claude-opus-4.6"], "pricing": { "input": 5.00, "cached_input": 0.50, @@ -205,21 +132,6 @@ "unit": "per_1m_tokens" } }, - { - "id": "claude-opus-4-20250522", - "name": "Claude 4 Opus", - "is_reasoning": false, - "description": "Claude 4 Opus model", - "alias": ["claude-opus-4"], - "pricing": { - "input": 15.00, - "cached_input": 1.50, - "cache_5m": 18.75, - "cache_1h": 30.00, - "output": 75.00, - "unit": "per_1m_tokens" - } - }, { "id": "claude-sonnet-4-5-20250929", "name": "Claude 4.5 Sonnet", @@ -240,21 +152,6 @@ "cache_1h": 7.50, "unit": "per_1m_tokens" } - }, - { - "id": "claude-sonnet-4-20250522", - "name": "Claude 4 
Sonnet", - "is_reasoning": false, - "description": "Claude 4 Sonnet model", - "alias": ["claude-sonnet-4"], - "pricing": { - "input": 3.00, - "cached_input": 0.30, - "cache_5m": 3.75, - "cache_1h": 6.00, - "output": 15.00, - "unit": "per_1m_tokens" - } } ], "gemini": [ @@ -301,6 +198,19 @@ "unit": "per_1m_tokens" } }, + { + "id": "gemini-3-flash-preview", + "name": "Gemini 3 Flash Preview", + "is_reasoning": true, + "description": "Google's fast and cost-efficient model with thinking capabilities", + "pricing": { + "input": 0.50, + "cached_input": 0.05, + "output": 3.00, + "storage": 1.00, + "unit": "per_1m_tokens" + } + }, { "id": "gemini-3-pro-preview", "name": "Gemini 3 Pro Preview", @@ -350,34 +260,6 @@ } } ], - "moonshot": [ - { - "id": "kimi-k2-0905-preview", - "name": "Kimi K2 0905 Preview", - "is_reasoning": false, - "description": "MoonshotAI's Kimi K2 model with prompt caching and large context window", - "pricing": { - "input": 0.60, - "cached_input": 0.15, - "output": 2.50, - "unit": "per_1m_tokens" - }, - "context_window": 262144 - }, - { - "id": "kimi-k2-thinking", - "name": "Kimi K2 0905 Preview", - "is_reasoning": false, - "description": "MoonshotAI's Kimi K2 model with thinking capabilities", - "pricing": { - "input": 0.60, - "cached_input": 0.15, - "output": 2.50, - "unit": "per_1m_tokens" - }, - "context_window": 262144 - } - ], "vllm": [ { "id": "gpt-oss-120b", @@ -484,67 +366,7 @@ } } ], - "doubao-anthropic": [ - { - "id": "doubao-seed-code-preview-251028", - "name": "Doubao Seed Code", - "is_reasoning": true, - "description": "Code-optimized reasoning model with extended thinking capabilities and input-dependent tiered pricing", - "pricing": { - "input_tiers": [ - {"max_tokens": 32000, "rate": 0.17}, - {"max_tokens": 128000, "rate": 0.20}, - {"max_tokens": null, "rate": 0.40} - ], - "output_tiers": [ - {"max_tokens": 32000, "rate": 1.14}, - {"max_tokens": 128000, "rate": 1.71}, - {"max_tokens": null, "rate": 2.29} - ], - "output_pricing_mode": 
"input_dependent", - "cached_input": 0.034, - "cache_storage": 0.0024, - "unit": "per_1m_tokens" - } - } - ], "volcengine": [ - { - "id": "doubao-seed-1.6-vision", - "name": "Doubao Seed 1.6 Vision", - "is_reasoning": true, - "description": "Vision model with thinking capabilities and tiered pricing (0.8-2.4 input based on token length)", - "alias": ["doubao-seed-1-6-vision-250615"], - "pricing": { - "input_tiers": [ - {"max_tokens": 32000, "rate": 0.11}, - {"max_tokens": 128000, "rate": 0.17}, - {"max_tokens": null, "rate": 0.34} - ], - "cache_storage": 0.0024, - "cache_hit": 0.023, - "output": 1.13, - "unit": "per_1m_tokens" - } - }, - { - "id": "doubao-seed-1.6", - "name": "Doubao Seed 1.6", - "is_reasoning": true, - "description": "Base model with thinking capabilities and tiered pricing (0.8-2.4 input, 2-24 output based on ratios)", - "alias": ["doubao-seed-1-6-251015", "doubao-seed-1-6"], - "pricing": { - "input_tiers": [ - {"max_tokens": 32000, "rate": 0.11}, - {"max_tokens": 128000, "rate": 0.17}, - {"max_tokens": null, "rate": 0.34} - ], - "cache_storage": 0.0024, - "cache_hit": 0.023, - "output": 1.13, - "unit": "per_1m_tokens" - } - }, { "id": "doubao-seed-1.8", "name": "Doubao Seed 1.8", @@ -567,135 +389,9 @@ "cache_hit": 0.022, "unit": "per_1m_tokens" } - }, - { - "id": "doubao-seed-1.6-thinking", - "name": "Doubao Seed 1.6 Thinking", - "is_reasoning": true, - "description": "Thinking-optimized model with tiered pricing (0.8-2.4 input, 8-24 output)", - "alias": ["doubao-seed-1-6-thinking-250615"], - "pricing": { - "input_tiers": [ - {"max_tokens": 32000, "rate": 0.11}, - {"max_tokens": 128000, "rate": 0.17}, - {"max_tokens": null, "rate": 0.34} - ], - "cache_storage": 0.0024, - "cache_hit": 0.023, - "output": 1.13, - "unit": "per_1m_tokens" - } - }, - { - "id": "doubao-seed-1.6-flash", - "name": "Doubao Seed 1.6 Flash", - "is_reasoning": true, - "description": "Fast and cost-efficient model with tiered pricing (0.15-0.6 input, 1.5-6 output)", - 
"alias": ["doubao-seed-1-6-flash-250615"], - "pricing": { - "input_tiers": [ - {"max_tokens": 32000, "rate": 0.021}, - {"max_tokens": 128000, "rate": 0.042}, - {"max_tokens": null, "rate": 0.085} - ], - "cache_storage": 0.0024, - "cache_hit": 0.0042, - "output": 0.42, - "unit": "per_1m_tokens" - } - }, - { - "id": "deepseek-v3-1-terminus", - "name": "DeepSeek V3.1", - "is_reasoning": true, - "description": "Hybrid reasoning model supporting both thinking and non-thinking inference modes", - "pricing": { - "input": 0.57, - "cached_input": 0.11, - "output": 1.71, - "cache_storage": 0.0024, - "unit": "per_1m_tokens" - } - }, - { - "id": "doubao-seed-translation", - "name": "Doubao Seed Translation", - "is_reasoning": false, - "description": "Multi-language translation model optimized for accuracy and fluency", - "pricing": { - "input": 0.17, - "output": 0.51, - "unit": "per_1m_tokens" - } - }, - { - "id": "doubao-seed-code-preview-251028", - "name": "Doubao Seed Code", - "is_reasoning": true, - "description": "Code-optimized reasoning model with extended thinking capabilities and input-dependent tiered pricing", - "pricing": { - "input_tiers": [ - {"max_tokens": 32000, "rate": 0.17}, - {"max_tokens": 128000, "rate": 0.20}, - {"max_tokens": null, "rate": 0.40} - ], - "output_tiers": [ - {"max_tokens": 32000, "rate": 1.14}, - {"max_tokens": 128000, "rate": 1.71}, - {"max_tokens": null, "rate": 2.29} - ], - "output_pricing_mode": "input_dependent", - "cached_input": 0.034, - "cache_storage": 0.0024, - "unit": "per_1m_tokens" - } } ], "dashscope": [ - { - "id": "qwen3-vl-plus", - "name": "Qwen3 VL Plus", - "is_reasoning": true, - "is_vision": true, - "description": "Alibaba's Qwen3 vision-language model (Plus tier) with thinking capabilities", - "context_window": 262144, - "max_output": 32768, - "pricing": { - "input_tiers": [ - {"max_tokens": 32000, "rate": 0.139, "cached_input": 0.028}, - {"max_tokens": 128000, "rate": 0.208, "cached_input": 0.042}, - {"max_tokens": 
null, "rate": 0.417, "cached_input": 0.083} - ], - "output_tiers": [ - {"max_tokens": 32000, "rate": 1.39}, - {"max_tokens": 128000, "rate": 2.08}, - {"max_tokens": null, "rate": 4.17} - ], - "unit": "per_1m_tokens" - } - }, - { - "id": "qwen3-vl-flash", - "name": "Qwen3 VL Flash", - "is_reasoning": true, - "is_vision": true, - "description": "Alibaba's Qwen3 vision-language model (Flash tier) - fast and cost-efficient", - "context_window": 262144, - "max_output": 32768, - "pricing": { - "input_tiers": [ - {"max_tokens": 32000, "rate": 0.021, "cached_input": 0.004}, - {"max_tokens": 128000, "rate": 0.042, "cached_input": 0.008}, - {"max_tokens": null, "rate": 0.083, "cached_input": 0.017} - ], - "output_tiers": [ - {"max_tokens": 32000, "rate": 0.21}, - {"max_tokens": 128000, "rate": 0.42}, - {"max_tokens": null, "rate": 0.83} - ], - "unit": "per_1m_tokens" - } - }, { "id": "qwen3-max", "name": "Qwen3 Max", @@ -737,69 +433,6 @@ ], "unit": "per_1m_tokens" } - }, - { - "id": "qwen-plus", - "name": "Qwen Plus a235b", - "is_reasoning": false, - "description": "Alibaba's Qwen Plus - balanced performance and cost", - "context_window": 1000000, - "max_output": 32768, - "pricing": { - "input_tiers": [ - {"max_tokens": 128000, "rate": 0.11, "cached_input": 0.022}, - {"max_tokens": 256000, "rate": 0.33, "cached_input": 0.066}, - {"max_tokens": null, "rate": 0.67, "cached_input": 0.134} - ], - "output_tiers": [ - {"max_tokens": 128000, "rate": 0.28}, - {"max_tokens": 256000, "rate": 2.78}, - {"max_tokens": null, "rate": 6.67} - ], - "unit": "per_1m_tokens" - } - }, - { - "id": "qwen-plus-2025-09-11", - "name": "Qwen Plus next 80b", - "is_reasoning": false, - "description": "Alibaba's Qwen Plus - balanced performance and cost", - "context_window": 1000000, - "max_output": 32768, - "pricing": { - "input_tiers": [ - {"max_tokens": 128000, "rate": 0.11, "cached_input": 0.022}, - {"max_tokens": 256000, "rate": 0.33, "cached_input": 0.066}, - {"max_tokens": null, "rate": 0.67, 
"cached_input": 0.134} - ], - "output_tiers": [ - {"max_tokens": 128000, "rate": 0.28}, - {"max_tokens": 256000, "rate": 2.78}, - {"max_tokens": null, "rate": 6.67} - ], - "unit": "per_1m_tokens" - } - }, - { - "id": "qwen-flash", - "name": "Qwen Flash", - "is_reasoning": false, - "description": "Alibaba's Qwen Flash - fastest and most cost-efficient", - "context_window": 1000000, - "max_output": 32768, - "pricing": { - "input_tiers": [ - {"max_tokens": 128000, "rate": 0.021, "cached_input": 0.004}, - {"max_tokens": 256000, "rate": 0.083, "cached_input": 0.017}, - {"max_tokens": null, "rate": 0.17, "cached_input": 0.034} - ], - "output_tiers": [ - {"max_tokens": 128000, "rate": 0.21}, - {"max_tokens": 256000, "rate": 0.83}, - {"max_tokens": null, "rate": 1.67} - ], - "unit": "per_1m_tokens" - } } ], "dashscope-sg": [ @@ -912,15 +545,21 @@ "sdk": "openai", "base_url": "https://api.openai.com/v1", "env_key": "OPENAI_API_KEY", - "use_response_api": false + "use_response_api": false, + "byok_eligible": false, + "display_name": "OpenAI" }, "anthropic": { "sdk": "anthropic", - "env_key": "ANTHROPIC_API_KEY" + "env_key": "ANTHROPIC_API_KEY", + "byok_eligible": false, + "display_name": "Anthropic" }, "gemini": { "sdk": "gemini", - "env_key": "GEMINI_API_KEY" + "env_key": "GEMINI_API_KEY", + "byok_eligible": true, + "display_name": "Gemini" }, "lm-studio": { "sdk": "openai", @@ -947,7 +586,9 @@ "openrouter": { "sdk": "deepseek", "base_url": "https://openrouter.ai/api/v1", - "env_key": "OPENROUTER_API_KEY" + "env_key": "OPENROUTER_API_KEY", + "byok_eligible": false, + "display_name": "OpenRouter" }, "moonshot": { "sdk": "anthropic", @@ -958,7 +599,9 @@ "deepseek": { "sdk": "anthropic", "base_url": "https://api.deepseek.com/anthropic", - "env_key": "DEEPSEEK_API_KEY" + "env_key": "DEEPSEEK_API_KEY", + "byok_eligible": false, + "display_name": "DeepSeek" }, "qwen": { "sdk": "openai", @@ -1016,11 +659,6 @@ "base_url": "https://api.cerebras.ai/v1", "env_key": "CEREBRAS_API_KEY", 
"use_response_api": false - }, - "anthropic-proxy": { - "sdk": "anthropic", - "base_url": "http://113.249.103.63:3000", - "env_key": "ANTHROPIC_PROXY_API_KEY" } }, "infrastructure_pricing": { @@ -1035,7 +673,7 @@ "BochaSearchTool": { "credits_per_use": 8, "search_type": "ai_search", - "pricing_note": "0.06 RMB per call (~$0.0083 USD at 7.2 exchange rate, AI Search endpoint)" + "pricing_note": "0.06 RMB per call" } }, "credit_conversion": { diff --git a/src/server/app/api_keys.py b/src/server/app/api_keys.py new file mode 100644 index 00000000..0765e479 --- /dev/null +++ b/src/server/app/api_keys.py @@ -0,0 +1,164 @@ +""" +API Keys and Models Router. + +Endpoints: +- GET /api/v1/users/me/api-keys — Get BYOK config (masked keys) +- PUT /api/v1/users/me/api-keys — Update BYOK config +- DELETE /api/v1/users/me/api-keys/{prov} — Remove one provider key +- GET /api/v1/models — List available models by provider +""" + +import logging +from typing import Dict, Optional + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel, field_validator + +from src.server.utils.api import CurrentUserId +from src.server.database.api_keys import ( + get_user_api_keys, + set_byok_enabled, + upsert_api_key, + delete_api_key, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["API Keys"]) + +# Module-level cache for BYOK-eligible providers (loaded once on first access) +_BYOK_PROVIDERS_CACHE: list[str] | None = None + + +def _get_supported_providers() -> list[str]: + """Get BYOK-eligible providers from LLM manifest (cached at module level).""" + global _BYOK_PROVIDERS_CACHE + if _BYOK_PROVIDERS_CACHE is None: + from src.llms.llm import ModelConfig + + config = ModelConfig() + _BYOK_PROVIDERS_CACHE = config.get_byok_eligible_providers() + return _BYOK_PROVIDERS_CACHE + + +def _get_provider_display_names() -> dict[str, str]: + """Get display names for BYOK-eligible providers from manifest.""" + from src.llms.llm import ModelConfig + + config = 
ModelConfig() + names = {} + for p in _get_supported_providers(): + info = config.get_provider_info(p) + names[p] = info.get("display_name", p.title()) + return names + + +def _mask_key(key: str) -> str: + """Mask an API key: show first 3 + last 4 chars.""" + if not key or len(key) < 8: + return "****" + return f"{key[:3]}...{key[-4:]}" + + +def _format_response(byok_enabled: bool, keys: dict) -> dict: + """Build the public response shape (never exposes full keys).""" + display_names = _get_provider_display_names() + providers = [] + for p in _get_supported_providers(): + raw = keys.get(p) + providers.append({ + "provider": p, + "display_name": display_names.get(p, p.title()), + "has_key": bool(raw), + "masked_key": _mask_key(raw) if raw else None, + }) + return {"byok_enabled": byok_enabled, "providers": providers} + + +# ── BYOK Endpoints ────────────────────────────────────────────────────── + + +@router.get("/api/v1/users/me/api-keys") +async def get_api_keys(user_id: CurrentUserId): + """Get user's BYOK configuration (keys are masked).""" + data = await get_user_api_keys(user_id) + return _format_response(data["byok_enabled"], data["keys"]) + + +class UpdateApiKeysRequest(BaseModel): + byok_enabled: Optional[bool] = None + api_keys: Optional[Dict[str, Optional[str]]] = None + + @field_validator("api_keys") + @classmethod + def validate_api_keys(cls, v): + if v is None: + return v + for provider, key in v.items(): + if key is not None: + if len(key) < 10 or len(key) > 256: + raise ValueError(f"API key for {provider} must be 10-256 chars") + if not key.isascii(): + raise ValueError(f"API key for {provider} must be ASCII") + return v + + +@router.put("/api/v1/users/me/api-keys") +async def update_api_keys(body: UpdateApiKeysRequest, user_id: CurrentUserId): + """ + Update BYOK settings. + + - byok_enabled: toggle the global switch + - api_keys: { "openai": "sk-..." 
} to set, { "openai": null } to delete + """ + # Toggle BYOK if requested + if body.byok_enabled is not None: + await set_byok_enabled(user_id, body.byok_enabled) + + # Upsert / delete individual provider keys + if body.api_keys: + supported = _get_supported_providers() + for provider, key_value in body.api_keys.items(): + if provider not in supported: + raise HTTPException( + status_code=400, + detail=f"Unsupported provider: {provider}. Supported: {supported}", + ) + if key_value is None: + await delete_api_key(user_id, provider) + else: + await upsert_api_key(user_id, provider, key_value) + + # Return updated state + data = await get_user_api_keys(user_id) + return _format_response(data["byok_enabled"], data["keys"]) + + +@router.delete("/api/v1/users/me/api-keys/{provider}") +async def remove_api_key(provider: str, user_id: CurrentUserId): + """Remove one provider's API key.""" + supported = _get_supported_providers() + if provider not in supported: + raise HTTPException( + status_code=400, + detail=f"Unsupported provider: {provider}. Supported: {supported}", + ) + await delete_api_key(user_id, provider) + data = await get_user_api_keys(user_id) + return _format_response(data["byok_enabled"], data["keys"]) + + +# ── Models Endpoint ────────────────────────────────────────────────────── + + +@router.get("/api/v1/models") +async def list_models(): + """ + List all configured LLM models grouped by provider. + + No auth required — this is public configuration info. 
+ """ + from src.llms.llm import get_configured_llm_models + + models = get_configured_llm_models() + return {"models": models} diff --git a/src/server/app/chat.py b/src/server/app/chat.py index 44fbd947..f0a9ec7e 100644 --- a/src/server/app/chat.py +++ b/src/server/app/chat.py @@ -69,7 +69,6 @@ build_skill_prefix_message, ) from src.server.utils.image_context import parse_image_contexts, inject_image_context -from src.server.utils.api import CurrentUserId from src.server.dependencies.usage_limits import ChatRateLimited from src.server.services.usage_limiter import UsageLimiter @@ -89,8 +88,98 @@ router = APIRouter(prefix="/api/v1/chat", tags=["Chat"]) +async def _resolve_byok_llm_client(user_id: str, model_name: str, byok_active: bool): + """ + If BYOK is active, look up the user's key for the model's provider + and return a fresh LLM client. Returns None if BYOK isn't applicable. + + Uses a single combined query (get_byok_key_for_provider) instead of + separate is_byok_active + get_key_for_provider calls. 
+ """ + if not byok_active: + return None + + from src.server.database.api_keys import get_byok_key_for_provider + from src.llms.llm import LLM as LLMFactory, create_llm + + mc = LLMFactory.get_model_config() + model_info = mc.get_model_config(model_name) + if not model_info: + return None + + provider = model_info["provider"] + user_key = await get_byok_key_for_provider(user_id, provider) + if not user_key: + return None + + logger.info(f"[CHAT] Using BYOK key for provider={provider}") + return create_llm(model_name, api_key=user_key) + + +# Maps agent mode → (config field on llm, preference key in other_preference) +_MODE_MODEL_MAP = { + "ptc": ("name", "preferred_model"), + "flash": ("flash", "preferred_flash_model"), +} + + +async def _get_model_preference(user_id: str) -> dict: + """Return model preferences from other_preference (not agent_preference, which is dumped to agent context).""" + from src.server.database.user import get_user_preferences + + prefs = await get_user_preferences(user_id) + if not prefs: + return {} + return prefs.get("other_preference") or {} + + +async def _resolve_llm_config( + base_config, + user_id: str, + request_model: str | None, + byok_active: bool, + mode: str = "ptc", +): + """ + Resolve final LLM config with priority: + per-request model > user preferred model > default. + Then inject BYOK client if active. + + Mode determines which config field and preference key to use + (see _MODE_MODEL_MAP). Easy to extend for new modes. 
+ """ + model_field, pref_key = _MODE_MODEL_MAP[mode] + config = base_config + + if request_model: + config = config.model_copy(deep=True) + setattr(config.llm, model_field, request_model) + config.llm_client = None + logger.info(f"[CHAT] Using per-request LLM model: {request_model}") + else: + model_pref = await _get_model_preference(user_id) + preferred = model_pref.get(pref_key) + if preferred: + config = config.model_copy(deep=True) + setattr(config.llm, model_field, preferred) + config.llm_client = None + logger.info(f"[CHAT] Using {pref_key}: {preferred}") + else: + logger.info(f"[CHAT] No {pref_key} set, using system default: {getattr(config.llm, model_field, None) or config.llm.name}") + + # BYOK injection — resolve the effective model from whichever field we just set + effective_model = getattr(config.llm, model_field, None) or config.llm.name + byok_client = await _resolve_byok_llm_client(user_id, effective_model, byok_active) + if byok_client: + if config is base_config: + config = config.model_copy(deep=True) + config.llm_client = byok_client + + return config + + @router.post("/stream") -async def chat_stream(request: ChatRequest, user_id: ChatRateLimited): +async def chat_stream(request: ChatRequest, auth: ChatRateLimited): """ Stream PTC agent responses as Server-Sent Events. 
@@ -107,6 +196,10 @@ async def chat_stream(request: ChatRequest, user_id: ChatRateLimited): Returns: StreamingResponse with SSE events """ + # Extract user_id and byok_active from auth result + user_id = auth.user_id + byok_active = auth.byok_active + # Determine agent mode agent_mode = request.agent_mode or "ptc" @@ -160,6 +253,7 @@ async def chat_stream(request: ChatRequest, user_id: ChatRateLimited): thread_id=thread_id, user_input=user_input, user_id=user_id, + byok_active=byok_active, ), media_type="text/event-stream", headers={ @@ -176,6 +270,7 @@ async def chat_stream(request: ChatRequest, user_id: ChatRateLimited): user_input=user_input, user_id=user_id, workspace_id=workspace_id, + byok_active=byok_active, ), media_type="text/event-stream", headers={ @@ -191,6 +286,7 @@ async def _astream_flash_workflow( thread_id: str, user_input: str, user_id: str, + byok_active: bool = False, ): """ Async generator that streams Flash agent workflow events. @@ -272,17 +368,10 @@ async def _astream_flash_workflow( # Build Flash Agent Graph # ===================================================================== - # Resolve LLM config for this request - config = setup.agent_config - if request.llm_model: - # Per-request LLM override takes precedence - config = config.model_copy(deep=True) - config.llm.flash = request.llm_model - logger.info( - f"[FLASH_CHAT] Using per-request LLM model: {request.llm_model}" - ) - elif config.llm.flash: - logger.info(f"[FLASH_CHAT] Using flash-specific LLM: {config.llm.flash}") + # Resolve LLM config for this request (model override + preferred + BYOK) + config = await _resolve_llm_config( + setup.agent_config, user_id, request.llm_model, byok_active, mode="flash" + ) # Build flash graph (no sandbox, no session) flash_graph = build_flash_graph( @@ -437,6 +526,7 @@ async def _astream_workflow( user_input: str, user_id: str, workspace_id: str, + byok_active: bool = False, ): """ Async generator that streams PTC agent workflow events. 
@@ -573,13 +663,10 @@ async def _astream_workflow( # Session and Graph Setup # ===================================================================== - # Resolve LLM config for this request - config = setup.agent_config - if request.llm_model: - config = config.model_copy(deep=True) - config.llm.name = request.llm_model - config.llm_client = None # Force rebuild with new model - logger.info(f"[PTC_CHAT] Using per-request LLM model: {request.llm_model}") + # Resolve LLM config for this request (model override + preferred + BYOK) + config = await _resolve_llm_config( + setup.agent_config, user_id, request.llm_model, byok_active, mode="ptc" + ) subagents = request.subagents_enabled or config.subagents_enabled sandbox_id = None diff --git a/src/server/app/setup.py b/src/server/app/setup.py index 5fe3837d..7d57d1d4 100644 --- a/src/server/app/setup.py +++ b/src/server/app/setup.py @@ -335,6 +335,7 @@ async def send_wrapper(message): from src.server.app.sec_proxy import router as sec_proxy_router from src.server.app.usage import router as usage_router from src.server.app.plans import router as plans_router +from src.server.app.api_keys import router as api_keys_router # Include all routers app.include_router(chat_router) # /api/v1/chat/* - Main chat endpoint @@ -354,4 +355,5 @@ async def send_wrapper(message): app.include_router(sec_proxy_router) # /api/v1/sec-proxy/* - SEC EDGAR document proxy app.include_router(usage_router) # /api/v1/usage/* - Usage limits and code redemption app.include_router(plans_router) # /api/v1/plans - Plan definitions (public) +app.include_router(api_keys_router) # /api/v1/users/me/api-keys + /api/v1/models - BYOK & model config app.include_router(health_router) # /health - Health check diff --git a/src/server/app/usage.py b/src/server/app/usage.py index 9cba3a93..4f1ed1c3 100644 --- a/src/server/app/usage.py +++ b/src/server/app/usage.py @@ -40,12 +40,17 @@ def _plan_obj(plan_info): 'rank': plan_info.rank, } + from src.server.database.api_keys 
import is_byok_active + + byok_enabled = await is_byok_active(user_id) + if not UsageLimiter.is_enabled(): return { 'limits_enabled': False, 'plan': _plan_obj(svc.get_default_plan()), 'credits': {'used': 0.0, 'limit': -1, 'remaining': -1}, 'workspaces': {'active': 0, 'limit': -1, 'remaining': -1}, + 'byok_enabled': byok_enabled, } plan = await UsageLimiter.get_user_plan(user_id) @@ -71,6 +76,7 @@ def _plan_obj(plan_info): 'limit': workspace_limit, 'remaining': workspace_remaining, }, + 'byok_enabled': byok_enabled, } diff --git a/src/server/app/users.py b/src/server/app/users.py index c28a3e3d..db12f36b 100644 --- a/src/server/app/users.py +++ b/src/server/app/users.py @@ -310,11 +310,14 @@ async def update_preferences( if not user: raise_not_found("User") - # Convert Pydantic models to dicts for JSONB storage - risk_pref = request.risk_preference.model_dump(exclude_none=True) if request.risk_preference else None - investment_pref = request.investment_preference.model_dump(exclude_none=True) if request.investment_preference else None - agent_pref = request.agent_preference.model_dump(exclude_none=True) if request.agent_preference else None - other_pref = request.other_preference.model_dump(exclude_none=True) if request.other_preference else None + # Convert Pydantic models to dicts for JSONB storage. + # Use exclude_unset=True (not exclude_none=True) so explicitly-sent null + # values are preserved — _split_updates_and_deletes uses None to signal + # key deletion from the JSONB column. 
+ risk_pref = request.risk_preference.model_dump(exclude_unset=True) if request.risk_preference else None + investment_pref = request.investment_preference.model_dump(exclude_unset=True) if request.investment_preference else None + agent_pref = request.agent_preference.model_dump(exclude_unset=True) if request.agent_preference else None + other_pref = request.other_preference.model_dump(exclude_unset=True) if request.other_preference else None preferences = await upsert_user_preferences( user_id=user_id, diff --git a/src/server/database/api_keys.py b/src/server/database/api_keys.py new file mode 100644 index 00000000..ae6929fc --- /dev/null +++ b/src/server/database/api_keys.py @@ -0,0 +1,171 @@ +""" +Database CRUD for user API keys (BYOK support). + +Normalized schema: one row per (user_id, provider) in user_api_keys, +plus a byok_enabled boolean on the users table. + +All API keys are encrypted at rest using pgcrypto (pgp_sym_encrypt/decrypt). +Encryption is transparent to callers — functions accept and return plaintext strings. +""" + +import logging +import os +from typing import Any, Dict, Optional + +from psycopg.rows import dict_row + +from src.server.database.conversation import get_db_connection + +logger = logging.getLogger(__name__) + + +def _get_encryption_key() -> str: + """Return the symmetric encryption key for API key storage.""" + key = os.getenv("BYOK_ENCRYPTION_KEY") + if not key: + raise RuntimeError( + "BYOK_ENCRYPTION_KEY environment variable is not set. " + "Required for encrypting user API keys at rest." + ) + return key + + +async def get_user_api_keys(user_id: str) -> Dict[str, Any]: + """ + Get user's BYOK configuration: toggle + all provider keys (decrypted). + + Returns: + { byok_enabled: bool, keys: { provider: api_key_plaintext, ... 
} } + """ + enc_key = _get_encryption_key() + async with get_db_connection() as conn: + async with conn.cursor(row_factory=dict_row) as cur: + # Fetch byok toggle from users table + await cur.execute( + "SELECT byok_enabled FROM users WHERE user_id = %s", + (user_id,), + ) + user_row = await cur.fetchone() + byok_enabled = bool(user_row["byok_enabled"]) if user_row else False + + # Fetch all provider keys (decrypted) + await cur.execute( + "SELECT provider, pgp_sym_decrypt(api_key, %s) AS api_key " + "FROM user_api_keys WHERE user_id = %s ORDER BY provider", + (enc_key, user_id), + ) + rows = await cur.fetchall() + keys = {row["provider"]: row["api_key"] for row in rows} + + return {"byok_enabled": byok_enabled, "keys": keys} + + +async def set_byok_enabled(user_id: str, enabled: bool) -> bool: + """ + Set the global BYOK toggle on the users table. + + Returns: + The new byok_enabled value. + """ + async with get_db_connection() as conn: + async with conn.cursor(row_factory=dict_row) as cur: + await cur.execute( + "UPDATE users SET byok_enabled = %s, updated_at = NOW() WHERE user_id = %s RETURNING byok_enabled", + (enabled, user_id), + ) + result = await cur.fetchone() + logger.info(f"[api_keys_db] set_byok_enabled user_id={user_id} enabled={enabled}") + return bool(result["byok_enabled"]) if result else False + + +async def upsert_api_key(user_id: str, provider: str, api_key: str) -> None: + """ + Insert or update a single provider key (encrypted). 
+ """ + enc_key = _get_encryption_key() + async with get_db_connection() as conn: + async with conn.cursor() as cur: + await cur.execute( + """ + INSERT INTO user_api_keys (user_id, provider, api_key, created_at, updated_at) + VALUES (%s, %s, pgp_sym_encrypt(%s, %s), NOW(), NOW()) + ON CONFLICT (user_id, provider) DO UPDATE + SET api_key = EXCLUDED.api_key, + updated_at = NOW() + """, + (user_id, provider, api_key, enc_key), + ) + logger.info(f"[api_keys_db] upsert_key user_id={user_id} provider={provider}") + + +async def delete_api_key(user_id: str, provider: str) -> None: + """ + Remove one provider key. + """ + async with get_db_connection() as conn: + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM user_api_keys WHERE user_id = %s AND provider = %s", + (user_id, provider), + ) + logger.info(f"[api_keys_db] delete_key user_id={user_id} provider={provider}") + + +async def get_key_for_provider(user_id: str, provider: str) -> Optional[str]: + """ + Quick lookup: return the decrypted API key for a specific provider, or None. + """ + enc_key = _get_encryption_key() + async with get_db_connection() as conn: + async with conn.cursor(row_factory=dict_row) as cur: + await cur.execute( + "SELECT pgp_sym_decrypt(api_key, %s) AS api_key " + "FROM user_api_keys WHERE user_id = %s AND provider = %s", + (enc_key, user_id, provider), + ) + row = await cur.fetchone() + return row["api_key"] if row else None + + +async def is_byok_active(user_id: str) -> bool: + """ + Quick check: is BYOK enabled AND does the user have at least one key? 
+ """ + async with get_db_connection() as conn: + async with conn.cursor(row_factory=dict_row) as cur: + await cur.execute( + """ + SELECT 1 FROM users u + WHERE u.user_id = %s + AND u.byok_enabled = TRUE + AND EXISTS ( + SELECT 1 FROM user_api_keys k WHERE k.user_id = u.user_id + ) + LIMIT 1 + """, + (user_id,), + ) + return (await cur.fetchone()) is not None + + +async def get_byok_key_for_provider(user_id: str, provider: str) -> Optional[str]: + """ + Combined query: return the decrypted API key only if BYOK is enabled. + + Returns None if BYOK is disabled OR no key exists for this provider. + Saves a round-trip vs calling is_byok_active() + get_key_for_provider() separately. + """ + enc_key = _get_encryption_key() + async with get_db_connection() as conn: + async with conn.cursor(row_factory=dict_row) as cur: + await cur.execute( + """ + SELECT pgp_sym_decrypt(k.api_key, %s) AS api_key + FROM user_api_keys k + JOIN users u ON u.user_id = k.user_id + WHERE k.user_id = %s AND k.provider = %s AND u.byok_enabled = TRUE + """, + (enc_key, user_id, provider), + ) + row = await cur.fetchone() + return row["api_key"] if row else None diff --git a/src/server/dependencies/usage_limits.py b/src/server/dependencies/usage_limits.py index e937742e..8719ec2f 100644 --- a/src/server/dependencies/usage_limits.py +++ b/src/server/dependencies/usage_limits.py @@ -8,6 +8,7 @@ Both are complete no-ops when auth is disabled. """ +from dataclasses import dataclass from typing import Annotated from fastapi import Depends, HTTPException @@ -16,19 +17,52 @@ from src.server.services.usage_limiter import UsageLimiter +@dataclass +class ChatAuthResult: + """Result from chat rate-limit dependency, carrying BYOK status to avoid re-querying.""" + user_id: str + byok_active: bool = False + + async def enforce_chat_limit( user_id: str = Depends(get_current_user_id), -) -> str: +) -> ChatAuthResult: """ FastAPI dependency: enforce daily credit limit + burst guard. 
Layer 1: DB credit check (SUM total_credits today vs tier daily_credits) Layer 2: Redis burst guard (concurrent in-flight request cap) - Returns user_id on success, raises HTTPException(429) if over limit. + BYOK users bypass the credit check but still face burst guard. + + Returns ChatAuthResult on success, raises HTTPException(429) if over limit. """ if not UsageLimiter.is_enabled(): - return user_id + return ChatAuthResult(user_id=user_id) + + # Check BYOK status once — reused downstream via ChatAuthResult + from src.server.database.api_keys import is_byok_active + + byok = await is_byok_active(user_id) + + if byok: + # BYOK bypasses credit limit, but still enforce burst guard + plan = await UsageLimiter.get_user_plan(user_id) + if plan.max_concurrent_requests != -1: + burst_result = await UsageLimiter._check_burst_guard( + user_id, plan.max_concurrent_requests + ) + if not burst_result['allowed']: + raise HTTPException( + status_code=429, + detail={ + 'message': 'Too many concurrent requests', + 'type': 'burst_limit', + 'retry_after': 5, + }, + headers={'Retry-After': '5'}, + ) + return ChatAuthResult(user_id=user_id, byok_active=True) result = await UsageLimiter.check_chat_limit(user_id) @@ -59,7 +93,7 @@ async def enforce_chat_limit( }, ) - return user_id + return ChatAuthResult(user_id=user_id) async def enforce_workspace_limit( @@ -96,5 +130,5 @@ async def enforce_workspace_limit( # Annotated types for cleaner endpoint signatures -ChatRateLimited = Annotated[str, Depends(enforce_chat_limit)] +ChatRateLimited = Annotated[ChatAuthResult, Depends(enforce_chat_limit)] WorkspaceLimitCheck = Annotated[str, Depends(enforce_workspace_limit)] From 11a2c7d13e422ccd32a85ef486369d44b8390704 Mon Sep 17 00:00:00 2001 From: Alan Chen Date: Wed, 11 Feb 2026 12:12:37 -0500 Subject: [PATCH 6/6] feat(web): add model preferences, BYOK management, and rate limit UX --- .../Dashboard/components/UserConfigPanel.jsx | 316 +++++++++++++++++- web/src/pages/Dashboard/utils/api.js | 
24 ++ .../TradingCenter/components/TradingPanel.jsx | 19 +- .../TradingCenter/hooks/useTradingChat.js | 56 ++-- web/src/pages/TradingCenter/utils/api.js | 9 + 5 files changed, 378 insertions(+), 46 deletions(-) diff --git a/web/src/pages/Dashboard/components/UserConfigPanel.jsx b/web/src/pages/Dashboard/components/UserConfigPanel.jsx index 865e9722..1da64faf 100644 --- a/web/src/pages/Dashboard/components/UserConfigPanel.jsx +++ b/web/src/pages/Dashboard/components/UserConfigPanel.jsx @@ -1,7 +1,7 @@ import React, { useState, useEffect, useRef } from 'react'; -import { X, User, LogOut } from 'lucide-react'; +import { X, User, LogOut, Eye, EyeOff, Trash2, HelpCircle } from 'lucide-react'; import { Input } from '../../../components/ui/input'; -import { updateCurrentUser, getCurrentUser, updatePreferences, getPreferences, uploadAvatar, redeemCode, getUsageStatus } from '../utils/api'; +import { updateCurrentUser, getCurrentUser, updatePreferences, getPreferences, uploadAvatar, redeemCode, getUsageStatus, getAvailableModels, getUserApiKeys, updateUserApiKeys, deleteUserApiKey } from '../utils/api'; import { useAuth } from '../../../contexts/AuthContext'; import ConfirmDialog from './ConfirmDialog'; @@ -38,9 +38,22 @@ function UserConfigPanel({ isOpen, onClose }) { const [usage, setUsage] = useState(null); + // Model tab state + const [availableModels, setAvailableModels] = useState({}); + const [preferredModel, setPreferredModel] = useState(''); + const [preferredFlashModel, setPreferredFlashModel] = useState(''); + const [byokEnabled, setByokEnabled] = useState(false); + const [byokProviders, setByokProviders] = useState([]); + const [keyInputs, setKeyInputs] = useState({}); + const [visibleKeys, setVisibleKeys] = useState({}); + const [deletingProvider, setDeletingProvider] = useState(null); + const [modelTabError, setModelTabError] = useState(null); + const [modelSaveSuccess, setModelSaveSuccess] = useState(false); + const [isSubmitting, setIsSubmitting] = 
useState(false); const [isLoading, setIsLoading] = useState(false); const [error, setError] = useState(null); + const [saveSuccess, setSaveSuccess] = useState(false); const [showLogoutConfirm, setShowLogoutConfirm] = useState(false); const timezones = [ @@ -91,6 +104,13 @@ function UserConfigPanel({ isOpen, onClose }) { } }, [isOpen]); + // Load model tab data lazily when tab is selected + useEffect(() => { + if (isOpen && activeTab === 'model') { + loadModelTabData(); + } + }, [isOpen, activeTab]); + const loadUserData = async () => { try { const userData = await getCurrentUser(); @@ -130,6 +150,81 @@ function UserConfigPanel({ isOpen, onClose }) { } }; + const loadModelTabData = async () => { + setModelTabError(null); + try { + const [modelsRes, keysRes, prefsRes] = await Promise.all([ + getAvailableModels(), + getUserApiKeys(), + getPreferences(), + ]); + setAvailableModels(modelsRes?.models || {}); + setByokEnabled(keysRes?.byok_enabled || false); + setByokProviders(keysRes?.providers || []); + setPreferredModel(prefsRes?.other_preference?.preferred_model || ''); + setPreferredFlashModel(prefsRes?.other_preference?.preferred_flash_model || ''); + } catch { + setModelTabError('Failed to load model settings'); + } + }; + + const handleModelTabSave = async () => { + setModelTabError(null); + setModelSaveSuccess(false); + setIsSubmitting(true); + try { + // 1. Save model preferences + await updatePreferences({ + other_preference: { + preferred_model: preferredModel || null, + preferred_flash_model: preferredFlashModel || null, + }, + }); + + // 2. 
Save any pending API key inputs + const pendingKeys = Object.entries(keyInputs).filter(([, v]) => v?.trim()); + for (const [provider, key] of pendingKeys) { + const result = await updateUserApiKeys({ api_keys: { [provider]: key.trim() } }); + setByokProviders(result.providers); + } + if (pendingKeys.length > 0) { + setKeyInputs({}); + } + + setModelSaveSuccess(true); + setTimeout(() => setModelSaveSuccess(false), 3000); + } catch { + setModelTabError('Failed to save settings'); + } finally { + setIsSubmitting(false); + } + }; + + const handleByokToggle = async () => { + setModelTabError(null); + const newValue = !byokEnabled; + try { + const result = await updateUserApiKeys({ byok_enabled: newValue }); + setByokEnabled(result.byok_enabled); + setByokProviders(result.providers); + } catch { + setModelTabError('Failed to toggle BYOK'); + } + }; + + const handleDeleteProviderKey = async (provider) => { + setDeletingProvider(provider); + setModelTabError(null); + try { + const result = await deleteUserApiKey(provider); + setByokProviders(result.providers); + } catch { + setModelTabError(`Failed to delete ${provider} key`); + } finally { + setDeletingProvider(null); + } + }; + const handleAvatarChange = async (e) => { const file = e.target.files[0]; if (!file) return; @@ -149,6 +244,7 @@ function UserConfigPanel({ isOpen, onClose }) { e.preventDefault(); setIsSubmitting(true); setError(null); + setSaveSuccess(false); try { const userData = {}; if (name.trim()) userData.name = name.trim(); @@ -156,9 +252,9 @@ function UserConfigPanel({ isOpen, onClose }) { if (locale) userData.locale = locale; if (Object.keys(userData).length > 0) { await updateCurrentUser(userData); - refreshUser(); } - onClose(); + setSaveSuccess(true); + setTimeout(() => setSaveSuccess(false), 3000); } catch (err) { setError(err.message || 'Failed to update user information'); } finally { @@ -170,6 +266,7 @@ function UserConfigPanel({ isOpen, onClose }) { e.preventDefault(); setIsSubmitting(true); 
setError(null); + setSaveSuccess(false); try { const preferences = {}; if (riskTolerance) preferences.risk_preference = { risk_tolerance: riskTolerance }; @@ -182,9 +279,9 @@ function UserConfigPanel({ isOpen, onClose }) { if (Object.keys(preferences).length > 0) { await updatePreferences(preferences); await updateCurrentUser({ onboarding_completed: true }); - refreshUser(); } - onClose(); + setSaveSuccess(true); + setTimeout(() => setSaveSuccess(false), 3000); } catch (err) { setError(err.message || 'Failed to update preferences'); } finally { @@ -224,8 +321,17 @@ function UserConfigPanel({ isOpen, onClose }) { ]; const getPlanBadgeStyle = (rank) => PLAN_BADGE_COLORS[Math.min(rank, PLAN_BADGE_COLORS.length - 1)]; + // Prevent Enter key in text inputs from submitting the enclosing
. + // Only the explicit submit button should trigger form submission. + const preventEnterSubmit = (e) => { + if (e.key === 'Enter' && e.target.tagName === 'INPUT' && e.target.type !== 'submit') { + e.preventDefault(); + } + }; + const handleClose = () => { setError(null); + setSaveSuccess(false); setRedeemError(null); setRedeemSuccess(null); onClose(); @@ -282,6 +388,17 @@ function UserConfigPanel({ isOpen, onClose }) { > Preferences +
{isLoading && ( @@ -291,7 +408,7 @@ function UserConfigPanel({ isOpen, onClose }) { )} {!isLoading && activeTab === 'userInfo' && ( - +
- Daily Credits + + Daily Credits + {usage.byok_enabled && ( + + BYOK + + )} + {usage.credits.limit === -1 ? 'Unlimited' @@ -498,10 +625,13 @@ function UserConfigPanel({ isOpen, onClose }) { > Logout -
+
+ {saveSuccess && ( + Saved + )}
@@ -518,7 +648,7 @@ function UserConfigPanel({ isOpen, onClose }) { )} {!isLoading && activeTab === 'preferences' && ( - +
setter(e.target.value)} + className="w-full rounded-md px-3 py-2 text-sm" + style={{ + backgroundColor: 'var(--color-bg-card)', + border: '1px solid var(--color-border-muted)', + color: 'var(--color-text-primary)', + }} + disabled={isSubmitting} + > + + {Object.entries(availableModels).map(([provider, models]) => ( + + {models.map((m) => ( + + ))} + + ))} + +
+ ))} + +
+ + {/* Section 2: BYOK */} +
+
+
+
+ +
+ +
+ Your API keys are stored using AES encryption and are never visible in plaintext. If you choose to delete a key, it is permanently removed from our records. +
+
+
+

+ Provide your own API keys to bypass credit limits. +

+
+ +
+ + {byokEnabled && ( +
+ {byokProviders.map((prov) => ( +
+ + {prov.display_name || prov.provider} + +
+ setKeyInputs((prev) => ({ ...prev, [prov.provider]: e.target.value }))} + placeholder={prov.has_key ? prov.masked_key : 'Enter API key...'} + className="w-full rounded-md px-3 py-1.5 pr-8 text-sm" + style={{ + backgroundColor: 'var(--color-bg-card)', + border: '1px solid var(--color-border-muted)', + color: 'var(--color-text-primary)', + }} + /> + +
+ {prov.has_key && ( + + )} +
+ ))} +
+ )} +
+ + {modelTabError && ( +
+

{modelTabError}

+
+ )} + +
+ {modelSaveSuccess && ( + Saved + )} + + +
+
+ )}
diff --git a/web/src/pages/Dashboard/utils/api.js b/web/src/pages/Dashboard/utils/api.js index de662529..34e38c5c 100644 --- a/web/src/pages/Dashboard/utils/api.js +++ b/web/src/pages/Dashboard/utils/api.js @@ -338,6 +338,30 @@ export const getPortfolio = portfolioApi.listPortfolio; /** Add portfolio holding. Payload: symbol, instrument_type, quantity, average_cost?, ... */ export const addPortfolioHolding = portfolioApi.addPortfolioHolding; +// --- Models --- + +export async function getAvailableModels() { + const { data } = await api.get('/api/v1/models'); + return data; +} + +// --- BYOK API Keys --- + +export async function getUserApiKeys() { + const { data } = await api.get('/api/v1/users/me/api-keys'); + return data; +} + +export async function updateUserApiKeys(payload) { + const { data } = await api.put('/api/v1/users/me/api-keys', payload); + return data; +} + +export async function deleteUserApiKey(provider) { + const { data } = await api.delete(`/api/v1/users/me/api-keys/${provider}`); + return data; +} + // --- InfoFlow (content feed) --- /** diff --git a/web/src/pages/TradingCenter/components/TradingPanel.jsx b/web/src/pages/TradingCenter/components/TradingPanel.jsx index d96cbb32..d608bda0 100644 --- a/web/src/pages/TradingCenter/components/TradingPanel.jsx +++ b/web/src/pages/TradingCenter/components/TradingPanel.jsx @@ -79,15 +79,26 @@ const TradingPanel = ({ messages = [], isLoading = false, error = null }) => { ) : (
({ - ...msg, - error: error && msg.id === messages[messages.length - 1]?.id ? error : msg.error - }))} + messages={messages} hideAvatar compactToolCalls onOpenSubagentTask={() => {}} onOpenFile={() => {}} /> + {error && ( +
+ {error} +
+ )}
)} diff --git a/web/src/pages/TradingCenter/hooks/useTradingChat.js b/web/src/pages/TradingCenter/hooks/useTradingChat.js index 3f69d1f1..bdb34e71 100644 --- a/web/src/pages/TradingCenter/hooks/useTradingChat.js +++ b/web/src/pages/TradingCenter/hooks/useTradingChat.js @@ -520,32 +520,19 @@ export function useTradingChat() { // Flush any remaining batched updates flushUpdates(); - // Mark message as not streaming - setMessages((prev) => - prev.map((msg) => { - if (msg.id !== assistantMessageId) return msg; - return { - ...msg, - isStreaming: false, - }; - }) - ); - - // Only set error if we haven't received any events - // If we received events, the stream might have just been interrupted but we have partial data - if (!hasReceivedEvents) { - setError(err.message || 'Failed to send message'); - setMessages((prev) => - prev.map((msg) => { - if (msg.id !== assistantMessageId) return msg; - return { - ...msg, - error: err.message || 'Failed to send message', - }; - }) - ); + // Handle rate limit (429) — show friendly message and remove empty assistant placeholder + if (err.status === 429) { + const info = err.rateLimitInfo || {}; + const limitMsg = info.type === 'credit_limit' + ? `Daily credit limit reached (${info.used_credits}/${info.credit_limit} credits). Resets at midnight UTC.` + : info.type === 'burst_limit' + ? 'Too many concurrent requests. Please wait a moment.' + : info.message || 'Rate limit exceeded. 
Please try again later.'; + setError(limitMsg); + // Remove the empty assistant placeholder — no content to show + setMessages((prev) => prev.filter((msg) => msg.id !== assistantMessageId)); } else { - // We received some events, so mark as complete even if stream was interrupted + // Mark message as not streaming setMessages((prev) => prev.map((msg) => { if (msg.id !== assistantMessageId) return msg; @@ -555,8 +542,23 @@ export function useTradingChat() { }; }) ); - if (process.env.NODE_ENV === 'development') { - console.warn('[TradingChat] Stream interrupted but received partial data, marking as complete'); + + // Only set error if we haven't received any events + if (!hasReceivedEvents) { + setError(err.message || 'Failed to send message'); + setMessages((prev) => + prev.map((msg) => { + if (msg.id !== assistantMessageId) return msg; + return { + ...msg, + error: err.message || 'Failed to send message', + }; + }) + ); + } else { + if (process.env.NODE_ENV === 'development') { + console.warn('[TradingChat] Stream interrupted but received partial data, marking as complete'); + } } } } finally { diff --git a/web/src/pages/TradingCenter/utils/api.js b/web/src/pages/TradingCenter/utils/api.js index d2d2e24c..62917304 100644 --- a/web/src/pages/TradingCenter/utils/api.js +++ b/web/src/pages/TradingCenter/utils/api.js @@ -450,6 +450,15 @@ async function streamFetch(url, opts, onEvent) { } if (!res.ok) { + // Handle 429 (rate limit) with structured detail + if (res.status === 429) { + let detail = {}; + try { detail = await res.json(); } catch { /* ignore */ } + const err = new Error(detail?.detail?.message || 'Rate limit exceeded'); + err.status = 429; + err.rateLimitInfo = detail?.detail || {}; + throw err; + } const errorText = await res.text().catch(() => 'Unknown error'); throw new Error(`HTTP error! status: ${res.status}, message: ${errorText}`); }