Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions alembic/versions/a58395ea1b22_add_balance_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""add balance system

Revision ID: a58395ea1b22
Revises: c9f3e548adef
Create Date: 2025-08-20 22:00:45.743308

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'a58395ea1b22'
down_revision = 'c9f3e548adef'
branch_labels = None
depends_on = None

def upgrade() -> None:
op.create_table(
'wallets',
sa.Column('account_id', sa.BigInteger(), nullable=False),
sa.Column('currency', sa.CHAR(length=3), nullable=False, server_default='USD'),
sa.Column('balance', sa.DECIMAL(precision=20, scale=6), nullable=False, server_default='0'),
sa.Column('blocked', sa.Boolean(), nullable=False, server_default='FALSE'),
sa.Column('version', sa.BigInteger(), nullable=False, server_default='0'),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('now()')),
sa.PrimaryKeyConstraint('account_id'),
sa.ForeignKeyConstraint(['account_id'], ['users.id'], ondelete='CASCADE')
)
op.add_column('provider_keys', sa.Column('billable', sa.Boolean(), nullable=False, server_default='FALSE'))

def downgrade() -> None:
op.drop_table('wallets')
op.drop_column('provider_keys', 'billable')
37 changes: 37 additions & 0 deletions app/api/routes/wallet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from decimal import Decimal
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel

from app.api.dependencies import get_current_active_user, get_current_active_user_from_clerk
from app.core.database import get_async_db
from app.models.user import User
from app.services.wallet_service import WalletService

router = APIRouter()

class WalletResponse(BaseModel):
balance: Decimal
blocked: bool
currency: str

@router.get("/balance", response_model=WalletResponse)
async def get_wallet_balance(
user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_async_db)
):
"""Get current wallet balance"""
wallet = await WalletService.get(db, user.id)

if not wallet:
await WalletService.ensure_wallet(db, user.id)
return WalletResponse(balance=Decimal("0"), blocked=False, currency="USD")

return WalletResponse(**wallet)

@router.get("/balance/clerk", response_model=WalletResponse)
async def get_wallet_balance_clerk(
user: User = Depends(get_current_active_user_from_clerk),
db: AsyncSession = Depends(get_async_db)
):
return await get_wallet_balance(user, db)
138 changes: 137 additions & 1 deletion app/api/routes/webhooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@
from app.core.security import generate_forge_api_key
from app.models.user import User
from app.services.provider_service import create_default_tensorblock_provider_for_user
from app.services.wallet_service import WalletService

logger = get_logger(name="webhooks")

router = APIRouter()

# Clerk webhook signing secret for verifying webhook authenticity
# Webhook signing secrets for verifying webhook authenticity
CLERK_WEBHOOK_SECRET = os.getenv("CLERK_WEBHOOK_SECRET", "")
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET", "")


@router.post("/clerk")
Expand Down Expand Up @@ -238,3 +240,137 @@ async def handle_user_deleted(event_data: dict, db: AsyncSession):
await db.rollback()
logger.error(f"Failed to delete user from webhook: {e}", exc_info=True)
raise


@router.post("/stripe")
async def stripe_webhook_handler(request: Request, db: AsyncSession = Depends(get_async_db)):
"""
Handle Stripe webhooks for payment events.

Key events to handle:
- payment_intent.succeeded: Credit wallet balance
- payment_intent.payment_failed: Log failed payment
- invoice.payment_failed: Handle subscription payment failure
"""
# Get the request body and signature
payload = await request.body()
sig_header = request.headers.get("stripe-signature")

if not sig_header:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing Stripe signature header"
)

# NOTE: For production, you would verify the webhook signature here
# This is a placeholder for the Stripe webhook verification
# Example:
# import stripe
# try:
# event = stripe.Webhook.construct_event(
# payload, sig_header, STRIPE_WEBHOOK_SECRET
# )
# except ValueError:
# raise HTTPException(status_code=400, detail="Invalid payload")
# except stripe.error.SignatureVerificationError:
# raise HTTPException(status_code=400, detail="Invalid signature")

try:
# For now, parse as JSON (would use verified event in production)
event_data = json.loads(payload)
event_type = event_data.get("type")

logger.info(f"Received Stripe webhook: {event_type}")

# Handle different event types
if event_type == "payment_intent.succeeded":
await handle_payment_succeeded(event_data, db)
elif event_type == "payment_intent.payment_failed":
await handle_payment_failed(event_data, db)
elif event_type == "invoice.payment_failed":
await handle_invoice_payment_failed(event_data, db)
else:
logger.info(f"Unhandled Stripe event type: {event_type}")

return {"status": "success", "message": f"Event {event_type} processed"}

except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid JSON payload"
)
except Exception as e:
logger.error(f"Error processing Stripe webhook: {e}", exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error processing webhook: {str(e)}",
)


async def handle_payment_succeeded(event_data: dict, db: AsyncSession):
"""Handle successful payment - credit wallet balance"""
try:
payment_intent = event_data.get("data", {}).get("object", {})
amount = payment_intent.get("amount", 0) # Amount in cents
currency = payment_intent.get("currency", "usd").upper()
customer_id = payment_intent.get("customer")

# Convert cents to dollars for USD
if currency == "USD":
amount_decimal = amount / 100.0
else:
amount_decimal = amount # Handle other currencies as needed

# TODO: Map customer_id to user_id
# For now, this is a placeholder - you'd need to implement customer mapping
# user_id = get_user_id_from_stripe_customer(customer_id)

logger.info(f"Payment succeeded: {amount_decimal} {currency} for customer {customer_id}")

# Uncomment when customer mapping is implemented:
# await WalletService.adjust(
# db,
# user_id,
# amount_decimal,
# f"deposit:stripe:{payment_intent.get('id')}",
# currency
# )

except Exception as e:
logger.error(f"Failed to process payment success: {e}", exc_info=True)
raise


async def handle_payment_failed(event_data: dict, db: AsyncSession):
"""Handle failed payment"""
try:
payment_intent = event_data.get("data", {}).get("object", {})
customer_id = payment_intent.get("customer")

logger.warning(f"Payment failed for customer {customer_id}")

# TODO: Implement failure handling logic
# - Notify user
# - Update payment status
# - Handle retry logic

except Exception as e:
logger.error(f"Failed to process payment failure: {e}", exc_info=True)
raise


async def handle_invoice_payment_failed(event_data: dict, db: AsyncSession):
"""Handle failed invoice payment - may need to block account"""
try:
invoice = event_data.get("data", {}).get("object", {})
customer_id = invoice.get("customer")

logger.warning(f"Invoice payment failed for customer {customer_id}")

# TODO: Map customer_id to user_id and potentially block account
# user_id = get_user_id_from_stripe_customer(customer_id)
# await WalletService.set_blocked(db, user_id, True)

except Exception as e:
logger.error(f"Failed to process invoice payment failure: {e}", exc_info=True)
raise
2 changes: 2 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
statistic,
stats,
users,
wallet,
webhooks,
)
from app.core.database import engine
Expand Down Expand Up @@ -167,6 +168,7 @@ def create_app() -> FastAPI:
v1_router.include_router(api_keys.router, prefix="/api-keys", tags=["api-keys"])
v1_router.include_router(proxy.router, tags=["proxy"])
v1_router.include_router(stats.router, prefix="/stats", tags=["stats"])
v1_router.include_router(wallet.router, prefix="/wallet", tags=["wallet"])
v1_router.include_router(webhooks.router, prefix="/webhooks", tags=["webhooks"])
v1_router.include_router(statistic.router, prefix='/statistic', tags=["statistic"])
# Claude Code compatible API endpoints
Expand Down
3 changes: 2 additions & 1 deletion app/models/provider_key.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import Column, ForeignKey, Integer, String, JSON
from sqlalchemy import Column, ForeignKey, Integer, String, JSON, Boolean
from sqlalchemy.orm import relationship

from app.models.forge_api_key import forge_api_key_provider_scope_association
Expand All @@ -19,6 +19,7 @@ class ProviderKey(BaseModel):
String, nullable=True
) # Allow custom base URLs for some providers
model_mapping = Column(JSON, nullable=True) # JSON dict for model name mappings
billable = Column(Boolean, nullable=False, default=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add a "billable" annotation at the model level instead of the provider level?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an open discussion. My current implementation is based on the assumption that the billing object is at "provider" level. You could have "free model" by setting the price to be zero.


# Relationship to ForgeApiKeys that have this provider key in their scope
scoped_forge_api_keys = relationship(
Expand Down
1 change: 1 addition & 0 deletions app/models/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ class User(Base):
provider_keys = relationship(
"ProviderKey", back_populates="user", cascade="all, delete-orphan"
)
wallet = relationship("Wallet", back_populates="user", uselist=False)
# Optional: Add relationship to ApiRequestLog if needed
# api_logs = relationship("ApiRequestLog")
17 changes: 17 additions & 0 deletions app/models/wallet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from datetime import datetime, UTC
from sqlalchemy import Column, BigInteger, CHAR, DECIMAL, Boolean, DateTime, ForeignKey
from sqlalchemy.orm import relationship
from .base import Base

class Wallet(Base):
__tablename__ = "wallets"

account_id = Column(BigInteger, ForeignKey("users.id", ondelete="CASCADE"), primary_key=True)
currency = Column(CHAR(3), nullable=False, default='USD')
balance = Column(DECIMAL(20, 6), nullable=False, default=0)
blocked = Column(Boolean, nullable=False, default=False)
version = Column(BigInteger, nullable=False, default=0)
created_at = Column(DateTime(timezone=True), nullable=False, default=datetime.now(UTC))
updated_at = Column(DateTime(timezone=True), nullable=False, default=datetime.now(UTC))

user = relationship("User", back_populates="wallet")
4 changes: 3 additions & 1 deletion app/services/provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import time
from collections.abc import AsyncGenerator
from typing import Any, ClassVar

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

Expand All @@ -20,6 +19,7 @@
)
from app.models.user import User
from app.core.database import get_db_session
from app.services.wallet_service import WalletService

from .providers.adapter_factory import ProviderAdapterFactory
from .providers.base import ProviderAdapter
Expand Down Expand Up @@ -597,6 +597,7 @@ async def process_request(
# Process the request through the adapter
usage_tracker_id = None
if self.api_key_id is not None and provider_key_id is not None:
await WalletService.wallet_precheck(self.user_id, self.db, provider_key_id)
usage_tracker_id = await UsageTrackerService.start_tracking_usage(
db=self.db,
user_id=self.user_id,
Expand Down Expand Up @@ -854,6 +855,7 @@ async def create_default_tensorblock_provider_for_user(
encrypted_api_key=encrypt_api_key(serialized_api_key_config),
user_id=user_id,
base_url=tensorblock_base_url,
billable=True,
model_mapping=None, # TensorBlock adapter handles model mapping internally
)

Expand Down
17 changes: 17 additions & 0 deletions app/services/providers/usage_tracker_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from app.core.logger import get_logger
from app.models.usage_tracker import UsageTracker
from app.services.pricing_service import PricingService
from app.services.wallet_service import WalletService

logger = get_logger(name="usage_tracker")

Expand Down Expand Up @@ -80,6 +81,22 @@ async def update_usage_tracker(
usage_tracker.cost = price_info['total_cost']
usage_tracker.currency = price_info['currency']
usage_tracker.pricing_source = price_info['pricing_source']

# Deduct from wallet balance if the provider is not free
if price_info['total_cost'] and price_info['total_cost'] > 0 and usage_tracker.provider_key.billable:
try:
result = await WalletService.adjust(
db,
usage_tracker.user_id,
-price_info['total_cost'],
f"usage:{usage_tracker.endpoint}",
price_info['currency']
)
if not result.get("success"):
logger.warning(f"Failed to deduct from wallet for user {usage_tracker.user_id}: {result.get('reason')}")
except Exception as wallet_err:
logger.exception(f"Wallet deduction failed for user {usage_tracker.user_id}: {wallet_err}")

await db.commit()
logger.debug(f"Updated usage tracker {usage_tracker_id} with input_tokens {input_tokens}, output_tokens {output_tokens}, cached_tokens {cached_tokens}, reasoning_tokens {reasoning_tokens}")
except NoResultFound:
Expand Down
Loading
Loading