diff --git a/alembic/versions/a58395ea1b22_add_balance_system.py b/alembic/versions/a58395ea1b22_add_balance_system.py new file mode 100644 index 0000000..19c34d0 --- /dev/null +++ b/alembic/versions/a58395ea1b22_add_balance_system.py @@ -0,0 +1,35 @@ +"""add balance system + +Revision ID: a58395ea1b22 +Revises: c9f3e548adef +Create Date: 2025-08-20 22:00:45.743308 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'a58395ea1b22' +down_revision = 'c9f3e548adef' +branch_labels = None +depends_on = None + +def upgrade() -> None: + op.create_table( + 'wallets', + sa.Column('account_id', sa.BigInteger(), nullable=False), + sa.Column('currency', sa.CHAR(length=3), nullable=False, server_default='USD'), + sa.Column('balance', sa.DECIMAL(precision=20, scale=6), nullable=False, server_default='0'), + sa.Column('blocked', sa.Boolean(), nullable=False, server_default='FALSE'), + sa.Column('version', sa.BigInteger(), nullable=False, server_default='0'), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')), + sa.PrimaryKeyConstraint('account_id'), + sa.ForeignKeyConstraint(['account_id'], ['users.id'], ondelete='CASCADE') + ) + op.add_column('provider_keys', sa.Column('billable', sa.Boolean(), nullable=False, server_default='FALSE')) + +def downgrade() -> None: + op.drop_table('wallets') + op.drop_column('provider_keys', 'billable') diff --git a/app/api/routes/wallet.py b/app/api/routes/wallet.py new file mode 100644 index 0000000..70c4ada --- /dev/null +++ b/app/api/routes/wallet.py @@ -0,0 +1,37 @@ +from decimal import Decimal +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession +from pydantic import BaseModel + +from app.api.dependencies import get_current_active_user, get_current_active_user_from_clerk +from app.core.database import get_async_db +from app.models.user import User +from app.services.wallet_service import WalletService + +router = APIRouter() + +class WalletResponse(BaseModel): + balance: Decimal + blocked: bool + currency: str + +@router.get("/balance", response_model=WalletResponse) +async def get_wallet_balance( + user: User = Depends(get_current_active_user), + db: AsyncSession = Depends(get_async_db) +): + """Get current wallet balance""" + wallet = await WalletService.get(db, user.id) + + if not wallet: + await WalletService.ensure_wallet(db, user.id) + return WalletResponse(balance=Decimal("0"), blocked=False, currency="USD") + + return WalletResponse(**wallet) + +@router.get("/balance/clerk", response_model=WalletResponse) +async def get_wallet_balance_clerk( + user: User = Depends(get_current_active_user_from_clerk), + db: AsyncSession = Depends(get_async_db) +): + return await get_wallet_balance(user, db) diff --git a/app/api/routes/webhooks.py b/app/api/routes/webhooks.py index 2dac064..f234872 100644 --- a/app/api/routes/webhooks.py +++ b/app/api/routes/webhooks.py @@ -13,13 +13,15 @@ from app.core.security import generate_forge_api_key from app.models.user import User from app.services.provider_service import create_default_tensorblock_provider_for_user +from app.services.wallet_service import WalletService logger = get_logger(name="webhooks") router = APIRouter() -# Clerk webhook signing secret for verifying webhook authenticity +# Webhook signing secrets for verifying webhook authenticity CLERK_WEBHOOK_SECRET = os.getenv("CLERK_WEBHOOK_SECRET", "") +STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "") @router.post("/clerk") @@ -238,3 +240,137 @@ async def handle_user_deleted(event_data: dict, db: AsyncSession): await db.rollback() logger.error(f"Failed to delete user from webhook: {e}", exc_info=True) raise + + +@router.post("/stripe") +async def stripe_webhook_handler(request: Request, db: AsyncSession = Depends(get_async_db)): + """ + Handle Stripe webhooks for payment events. + + Key events to handle: + - payment_intent.succeeded: Credit wallet balance + - payment_intent.payment_failed: Log failed payment + - invoice.payment_failed: Handle subscription payment failure + """ + # Get the request body and signature + payload = await request.body() + sig_header = request.headers.get("stripe-signature") + + if not sig_header: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing Stripe signature header" + ) + + # NOTE: For production, you would verify the webhook signature here + # This is a placeholder for the Stripe webhook verification + # Example: + # import stripe + # try: + # event = stripe.Webhook.construct_event( + # payload, sig_header, STRIPE_WEBHOOK_SECRET + # ) + # except ValueError: + # raise HTTPException(status_code=400, detail="Invalid payload") + # except stripe.error.SignatureVerificationError: + # raise HTTPException(status_code=400, detail="Invalid signature") + + try: + # For now, parse as JSON (would use verified event in production) + event_data = json.loads(payload) + event_type = event_data.get("type") + + logger.info(f"Received Stripe webhook: {event_type}") + + # Handle different event types + if event_type == "payment_intent.succeeded": + await handle_payment_succeeded(event_data, db) + elif event_type == "payment_intent.payment_failed": + await handle_payment_failed(event_data, db) + elif event_type == "invoice.payment_failed": + await handle_invoice_payment_failed(event_data, db) + else: + logger.info(f"Unhandled Stripe event type: {event_type}") + + return {"status": "success", "message": f"Event {event_type} processed"} + + except json.JSONDecodeError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid JSON payload" + ) + except Exception as e: + logger.error(f"Error processing Stripe webhook: {e}", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error processing webhook: {str(e)}", + ) + + +async def handle_payment_succeeded(event_data: dict, db: AsyncSession): + """Handle successful payment - credit wallet balance""" + try: + payment_intent = event_data.get("data", {}).get("object", {}) + amount = payment_intent.get("amount", 0) # Amount in cents + currency = payment_intent.get("currency", "usd").upper() + customer_id = payment_intent.get("customer") + + # Convert cents to dollars for USD + if currency == "USD": + amount_decimal = amount / 100.0 + else: + amount_decimal = amount # Handle other currencies as needed + + # TODO: Map customer_id to user_id + # For now, this is a placeholder - you'd need to implement customer mapping + # user_id = get_user_id_from_stripe_customer(customer_id) + + logger.info(f"Payment succeeded: {amount_decimal} {currency} for customer {customer_id}") + + # Uncomment when customer mapping is implemented: + # await WalletService.adjust( + # db, + # user_id, + # amount_decimal, + # f"deposit:stripe:{payment_intent.get('id')}", + # currency + # ) + + except Exception as e: + logger.error(f"Failed to process payment success: {e}", exc_info=True) + raise + + +async def handle_payment_failed(event_data: dict, db: AsyncSession): + """Handle failed payment""" + try: + payment_intent = event_data.get("data", {}).get("object", {}) + customer_id = payment_intent.get("customer") + + logger.warning(f"Payment failed for customer {customer_id}") + + # TODO: Implement failure handling logic + # - Notify user + # - Update payment status + # - Handle retry logic + + except Exception as e: + logger.error(f"Failed to process payment failure: {e}", exc_info=True) + raise + + +async def handle_invoice_payment_failed(event_data: dict, db: AsyncSession): + """Handle failed invoice payment - may need to block account""" + try: + invoice = event_data.get("data", {}).get("object", {}) + customer_id = invoice.get("customer") + + logger.warning(f"Invoice payment failed for customer {customer_id}") + + # TODO: Map customer_id to user_id and potentially block account + # user_id = get_user_id_from_stripe_customer(customer_id) + # await WalletService.set_blocked(db, user_id, True) + + except Exception as e: + logger.error(f"Failed to process invoice payment failure: {e}", exc_info=True) + raise diff --git a/app/main.py b/app/main.py index 4f1a674..40d1b72 100644 --- a/app/main.py +++ b/app/main.py @@ -18,6 +18,7 @@ statistic, stats, users, + wallet, webhooks, ) from app.core.database import engine @@ -167,6 +168,7 @@ def create_app() -> FastAPI: v1_router.include_router(api_keys.router, prefix="/api-keys", tags=["api-keys"]) v1_router.include_router(proxy.router, tags=["proxy"]) v1_router.include_router(stats.router, prefix="/stats", tags=["stats"]) + v1_router.include_router(wallet.router, prefix="/wallet", tags=["wallet"]) v1_router.include_router(webhooks.router, prefix="/webhooks", tags=["webhooks"]) v1_router.include_router(statistic.router, prefix='/statistic', tags=["statistic"]) # Claude Code compatible API endpoints diff --git a/app/models/provider_key.py b/app/models/provider_key.py index 5300190..7762c70 100644 --- a/app/models/provider_key.py +++ b/app/models/provider_key.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, ForeignKey, Integer, String, JSON +from sqlalchemy import Column, ForeignKey, Integer, String, JSON, Boolean from sqlalchemy.orm import relationship from app.models.forge_api_key import forge_api_key_provider_scope_association @@ -19,6 +19,7 @@ class ProviderKey(BaseModel): String, nullable=True ) # Allow custom base URLs for some providers model_mapping = Column(JSON, nullable=True) # JSON dict for model name mappings + billable = Column(Boolean, nullable=False, default=False) # Relationship to ForgeApiKeys that have this provider key in their scope scoped_forge_api_keys = relationship( diff --git a/app/models/user.py b/app/models/user.py index ec09d54..4d2d1bc 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -32,5 +32,6 @@ class User(Base): provider_keys = relationship( "ProviderKey", back_populates="user", cascade="all, delete-orphan" ) + wallet = relationship("Wallet", back_populates="user", uselist=False) # Optional: Add relationship to ApiRequestLog if needed # api_logs = relationship("ApiRequestLog") diff --git a/app/models/wallet.py b/app/models/wallet.py new file mode 100644 index 0000000..a1e3a57 --- /dev/null +++ b/app/models/wallet.py @@ -0,0 +1,17 @@ +from datetime import datetime, UTC +from sqlalchemy import Column, BigInteger, CHAR, DECIMAL, Boolean, DateTime, ForeignKey +from sqlalchemy.orm import relationship +from .base import Base + +class Wallet(Base): + __tablename__ = "wallets" + + account_id = Column(BigInteger, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True) + currency = Column(CHAR(3), nullable=False, default='USD') + balance = Column(DECIMAL(20, 6), nullable=False, default=0) + blocked = Column(Boolean, nullable=False, default=False) + version = Column(BigInteger, nullable=False, default=0) + created_at = Column(DateTime(timezone=True), nullable=False, default=datetime.now(UTC)) + updated_at = Column(DateTime(timezone=True), nullable=False, default=datetime.now(UTC)) + + user = relationship("User", back_populates="wallet") diff --git a/app/services/provider_service.py b/app/services/provider_service.py index b4d5235..1ef8af3 100644 --- a/app/services/provider_service.py +++ b/app/services/provider_service.py @@ -6,7 +6,6 @@ import time from collections.abc import AsyncGenerator from typing import Any, ClassVar - from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -20,6 +19,7 @@ ) from app.models.user import User from app.core.database import get_db_session +from app.services.wallet_service import WalletService from .providers.adapter_factory import ProviderAdapterFactory from .providers.base import ProviderAdapter @@ -597,6 +597,7 @@ async def process_request( # Process the request through the adapter usage_tracker_id = None if self.api_key_id is not None and provider_key_id is not None: + await WalletService.wallet_precheck(self.user_id, self.db, provider_key_id) usage_tracker_id = await UsageTrackerService.start_tracking_usage( db=self.db, user_id=self.user_id, @@ -854,6 +855,7 @@ async def create_default_tensorblock_provider_for_user( encrypted_api_key=encrypt_api_key(serialized_api_key_config), user_id=user_id, base_url=tensorblock_base_url, + billable=True, model_mapping=None, # TensorBlock adapter handles model mapping internally ) diff --git a/app/services/providers/usage_tracker_service.py b/app/services/providers/usage_tracker_service.py index c2fb8a7..187f0e5 100644 --- a/app/services/providers/usage_tracker_service.py +++ b/app/services/providers/usage_tracker_service.py @@ -10,6 +10,7 @@ from app.core.logger import get_logger from app.models.usage_tracker import UsageTracker from app.services.pricing_service import PricingService +from app.services.wallet_service import WalletService logger = get_logger(name="usage_tracker") @@ -80,6 +81,22 @@ async def update_usage_tracker( usage_tracker.cost = price_info['total_cost'] usage_tracker.currency = price_info['currency'] usage_tracker.pricing_source = price_info['pricing_source'] + + # Deduct from wallet balance if the provider is not free + if price_info['total_cost'] and price_info['total_cost'] > 0 and usage_tracker.provider_key.billable: + try: + result = await WalletService.adjust( + db, + usage_tracker.user_id, + -price_info['total_cost'], + f"usage:{usage_tracker.endpoint}", + price_info['currency'] + ) + if not result.get("success"): + logger.warning(f"Failed to deduct from wallet for user {usage_tracker.user_id}: {result.get('reason')}") + except Exception as wallet_err: + logger.exception(f"Wallet deduction failed for user {usage_tracker.user_id}: {wallet_err}") + await db.commit() logger.debug(f"Updated usage tracker {usage_tracker_id} with input_tokens {input_tokens}, output_tokens {output_tokens}, cached_tokens {cached_tokens}, reasoning_tokens {reasoning_tokens}") except NoResultFound: diff --git a/app/services/wallet_service.py b/app/services/wallet_service.py new file mode 100644 index 0000000..1f96e8c --- /dev/null +++ b/app/services/wallet_service.py @@ -0,0 +1,209 @@ +from datetime import datetime, UTC +from typing import Dict, Optional +from decimal import Decimal +import asyncio +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession +from fastapi import HTTPException + +from app.core.logger import get_logger +from app.models.wallet import Wallet +from app.models.provider_key import ProviderKey + +logger = get_logger(name="wallet_service") + +# Retry configuration +MAX_RETRIES = 3 +RETRY_DELAY_MS = 10 # milliseconds + +class WalletService: + @staticmethod + async def ensure_wallet(db: AsyncSession, account_id: int) -> None: + """Create wallet if it doesn't exist""" + try: + result = await db.execute(select(Wallet).filter(Wallet.account_id == account_id)) + if result.scalar_one_or_none() is None: + wallet = Wallet(account_id=account_id) + db.add(wallet) + await db.commit() + logger.debug(f"Created wallet for account {account_id}") + except Exception as e: + await db.rollback() + logger.error(f"Failed to ensure wallet for account {account_id}: {e}") + raise + + @staticmethod + async def precheck(db: AsyncSession, account_id: int) -> Dict[str, any]: + """Check if user can make requests""" + try: + result = await db.execute(select(Wallet).filter(Wallet.account_id == account_id)) + wallet = result.scalar_one_or_none() + + if not wallet: + await WalletService.ensure_wallet(db, account_id) + return {"blocked": False, "balance": Decimal("0"), "allowed": False} + + allowed = not wallet.blocked and wallet.balance > 0 + return { + "blocked": wallet.blocked, + "balance": wallet.balance, + "allowed": allowed + } + except Exception as e: + logger.exception(f"Precheck failed for account {account_id}: {e}") + raise + + @staticmethod + async def adjust( + db: AsyncSession, + account_id: int, + delta: Decimal, + reason: str, + currency: str = "USD" + ) -> Dict[str, any]: + """Adjust wallet balance with optimistic locking and retry""" + for attempt in range(MAX_RETRIES): + try: + # Read current wallet state including version + result = await db.execute(select(Wallet).where(Wallet.account_id == account_id)) + wallet = result.scalar_one_or_none() + + if not wallet: + await WalletService.ensure_wallet(db, account_id) + continue # Retry after creating wallet + + current_version = wallet.version + + # Attempt optimistic update - always allow deductions (oversubscription is OK) + stmt = update(Wallet).where( + (Wallet.account_id == account_id) & + (Wallet.version == current_version) + ).values( + balance=Wallet.balance + delta, + updated_at=datetime.now(UTC), + version=Wallet.version + 1 + ).returning(Wallet.balance, Wallet.blocked) + + result = await db.execute(stmt) + row = result.fetchone() + + if row is None: + # Version conflict - another process updated first + if attempt < MAX_RETRIES - 1: + await db.rollback() + await asyncio.sleep(RETRY_DELAY_MS / 1000.0) + logger.debug(f"Optimistic lock conflict for account {account_id}, retrying ({attempt + 1}/{MAX_RETRIES})") + continue + else: + await db.rollback() + logger.warning(f"Max retries exceeded for account {account_id} adjustment") + return {"success": False, "reason": "version_conflict"} + + await db.commit() + logger.debug(f"Adjusted balance for account {account_id} by {delta} ({reason}) after {attempt + 1} attempts") + return {"success": True, "balance": row[0], "blocked": row[1]} + + except Exception as e: + await db.rollback() + if attempt < MAX_RETRIES - 1: + logger.warning(f"Error adjusting balance for account {account_id}, retrying: {e}") + await asyncio.sleep(RETRY_DELAY_MS / 1000.0) + continue + else: + logger.error(f"Failed to adjust balance for account {account_id} after {MAX_RETRIES} attempts: {e}") + raise + + return {"success": False, "reason": "max_retries_exceeded"} + + @staticmethod + async def set_blocked(db: AsyncSession, account_id: int, blocked: bool) -> None: + """Block or unblock account with optimistic locking""" + for attempt in range(MAX_RETRIES): + try: + # Read current version + result = await db.execute(select(Wallet).where(Wallet.account_id == account_id)) + wallet = result.scalar_one_or_none() + + if not wallet: + logger.warning(f"Wallet not found for account {account_id}") + return + + current_version = wallet.version + + # Update with version check + result = await db.execute( + update(Wallet).where( + (Wallet.account_id == account_id) & + (Wallet.version == current_version) + ).values( + blocked=blocked, + updated_at=datetime.now(UTC), + version=Wallet.version + 1 + ) + ) + + if result.rowcount == 0: + # Version conflict + if attempt < MAX_RETRIES - 1: + await db.rollback() + await asyncio.sleep(RETRY_DELAY_MS / 1000.0) + logger.debug(f"Optimistic lock conflict setting blocked status for account {account_id}, retrying") + continue + else: + await db.rollback() + logger.warning(f"Failed to set blocked status after {MAX_RETRIES} attempts for account {account_id}") + return + + await db.commit() + logger.debug(f"Set blocked={blocked} for account {account_id} after {attempt + 1} attempts") + return + + except Exception as e: + await db.rollback() + if attempt < MAX_RETRIES - 1: + logger.warning(f"Error setting blocked status for account {account_id}, retrying: {e}") + await asyncio.sleep(RETRY_DELAY_MS / 1000.0) + continue + else: + logger.error(f"Failed to set blocked status for account {account_id} after {MAX_RETRIES} attempts: {e}") + raise + + @staticmethod + async def get(db: AsyncSession, account_id: int) -> Optional[Dict[str, any]]: + """Get wallet details""" + try: + result = await db.execute(select(Wallet).filter(Wallet.account_id == account_id)) + wallet = result.scalar_one_or_none() + + if not wallet: + return None + + return { + "balance": wallet.balance, + "blocked": wallet.blocked, + "currency": wallet.currency + } + except Exception as e: + logger.error(f"Failed to get wallet for account {account_id}: {e}") + raise + + # ------------------------------------------------------------- + # Helper: perform wallet precheck + # ------------------------------------------------------------- + @staticmethod + async def wallet_precheck(user_id: int, db: AsyncSession, provider_key_id: int) -> None: + """Check wallet balance and ensure user can make requests""" + provider_key = await db.execute(select(ProviderKey).filter(ProviderKey.id == provider_key_id, ProviderKey.billable)) + provider_key = provider_key.scalar_one_or_none() + # If the provider key is not billable, we don't need to check the wallet + if not provider_key: + return + + await WalletService.ensure_wallet(db, user_id) + check_result = await WalletService.precheck(db, user_id) + + if not check_result["allowed"]: + if check_result["blocked"]: + raise HTTPException(status_code=402, detail="Account blocked") + else: + raise HTTPException(status_code=402, detail="Insufficient balance") \ No newline at end of file