Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions alembic/versions/40e4b59f754d_add_stripe_payment_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""add stripe payment table

Revision ID: 40e4b59f754d
Revises: a58395ea1b22
Create Date: 2025-09-02 20:52:29.183031

"""
from alembic import op
import sqlalchemy as sa
from datetime import datetime, UTC


# revision identifiers, used by Alembic.
revision = '40e4b59f754d'
down_revision = 'a58395ea1b22'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table('stripe_payment',
sa.Column('id', sa.String, primary_key=True),
sa.Column('user_id', sa.Integer, sa.ForeignKey('users.id'), nullable=False),
sa.Column('amount', sa.Integer, nullable=False),
sa.Column('currency', sa.String(3), nullable=False),
sa.Column('status', sa.String, nullable=False),
sa.Column('raw_data', sa.JSON, nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), default=datetime.now(UTC)),
sa.Column('updated_at', sa.DateTime(timezone=True), default=datetime.now(UTC), onupdate=datetime.now(UTC)),
)


def downgrade() -> None:
op.drop_table('stripe_payment')
78 changes: 78 additions & 0 deletions app/api/routes/stripe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import os
from fastapi import APIRouter, Depends, Request
from app.api.dependencies import get_current_active_user_from_clerk, get_current_active_user
from app.api.schemas.stripe import CreateCheckoutSessionRequest
from app.models.user import User
from app.models.stripe import StripePayment
from app.core.database import get_async_db
from sqlalchemy.ext.asyncio import AsyncSession
import stripe
from app.core.logger import get_logger
from sqlalchemy import select
from fastapi import HTTPException

logger = get_logger(name="stripe")

STRIPE_API_KEY = os.getenv("STRIPE_API_KEY")
stripe.api_key = STRIPE_API_KEY

router = APIRouter()

@router.post("/create-checkout-session/clerk")
async def stripe_create_checkout_session_clerk(request: Request, create_checkout_session_request: CreateCheckoutSessionRequest, user: User = Depends(get_current_active_user_from_clerk), db: AsyncSession = Depends(get_async_db)):
return await stripe_create_checkout_session(request, create_checkout_session_request, user, db)


@router.post("/create-checkout-session")
async def stripe_create_checkout_session(request: Request, create_checkout_session_request: CreateCheckoutSessionRequest, user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_async_db)):
"""
Create a checkout session for a user.
"""
logger.info(f"Creating checkout session for user {user.id}")
session = await stripe.checkout.Session.create_async(
metadata={
"user_id": user.id,
},
**create_checkout_session_request.model_dump(exclude_none=True),
)
stripe_payment = StripePayment(
id=session.id,
user_id=user.id,
status=session.status,
currency=session.currency.upper(),
amount=session.amount_total,
# store the whole session as raw_data
raw_data=dict(session),
)
db.add(stripe_payment)
await db.commit()

return {
'session_id': session.id,
'url': session.url,
}

@router.get("/checkout-session")
async def stripe_get_checkout_session(session_id: str, user: User = Depends(get_current_active_user), db: AsyncSession = Depends(get_async_db)):
result = await db.execute(
select(
StripePayment
)
.where(StripePayment.id == session_id, StripePayment.user_id == user.id)
)
stripe_payment = result.scalar_one_or_none()
if not stripe_payment:
raise HTTPException(status_code=404, detail="Stripe payment not found")

return {
'id': stripe_payment.id,
'status': stripe_payment.status,
'currency': stripe_payment.currency,
'amount': stripe_payment.amount / 100.0 if stripe_payment.currency == "USD" else stripe_payment.amount,
'created_at': stripe_payment.created_at,
'updated_at': stripe_payment.updated_at,
}

@router.get("/checkout-session/clerk")
async def stripe_get_checkout_session_clerk(session_id: str, user: User = Depends(get_current_active_user_from_clerk), db: AsyncSession = Depends(get_async_db)):
return await stripe_get_checkout_session(session_id, user, db)
84 changes: 81 additions & 3 deletions app/api/routes/wallet.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
from decimal import Decimal
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from typing import List
from datetime import datetime

from app.api.dependencies import get_current_active_user, get_current_active_user_from_clerk
from app.core.database import get_async_db
from app.models.user import User
from app.models.stripe import StripePayment
from app.models.usage_tracker import UsageTracker
from app.services.wallet_service import WalletService
from sqlalchemy import select, desc, func

router = APIRouter()

class WalletResponse(BaseModel):
balance: Decimal
blocked: bool
currency: str
total_spent: Decimal
total_earned: Decimal

@router.get("/balance", response_model=WalletResponse)
async def get_wallet_balance(
Expand All @@ -25,13 +32,84 @@ async def get_wallet_balance(

if not wallet:
await WalletService.ensure_wallet(db, user.id)
return WalletResponse(balance=Decimal("0"), blocked=False, currency="USD")
return WalletResponse(balance=Decimal("0"), blocked=False, currency="USD", total_spent=Decimal("0"), total_earned=Decimal("0"))

return WalletResponse(**wallet)
result = await db.execute(select(func.sum(UsageTracker.cost)).where(UsageTracker.user_id == user.id, UsageTracker.updated_at.is_not(None)))
total_spent = result.scalar_one_or_none() or "0"
result = await db.execute(select(func.sum(StripePayment.amount)).where(StripePayment.user_id == user.id, StripePayment.status == "completed"))
total_earned = result.scalar_one_or_none() or "0"

return WalletResponse(**wallet, total_spent=Decimal(total_spent), total_earned=Decimal(total_earned))

@router.get("/balance/clerk", response_model=WalletResponse)
async def get_wallet_balance_clerk(
user: User = Depends(get_current_active_user_from_clerk),
db: AsyncSession = Depends(get_async_db)
):
return await get_wallet_balance(user, db)

class TransactionHistoryItem(BaseModel):
currency: str
amount: Decimal
status: str
created_at: datetime
updated_at: datetime

class TransactionHistoryResponse(BaseModel):
items: List[TransactionHistoryItem]
total: int
page_size: int
page_index: int

@router.get("/transactions/history", response_model=TransactionHistoryResponse)
async def get_wallet_transactions_history(
user: User = Depends(get_current_active_user),
db: AsyncSession = Depends(get_async_db),
page_size: int = Query(10, ge=1),
page_index: int = Query(0, ge=0),
status: str = Query(None, min_length=1),
started_at: datetime = Query(None),
):
# I would also want to get the total count of the transactions within one sql query
query = (
select(
StripePayment.currency,
StripePayment.amount,
StripePayment.status,
StripePayment.created_at,
StripePayment.updated_at,
func.count().over().label("total"),
)
.where(StripePayment.user_id == user.id, status is None or StripePayment.status == status, started_at is None or StripePayment.created_at >= started_at)
.order_by(desc(StripePayment.updated_at))
.offset(page_index * page_size)
.limit(page_size)
)
result = await db.execute(query)
transactions = result.fetchall()
return TransactionHistoryResponse(
items=[
TransactionHistoryItem(
currency=transaction.currency,
# Convert cents to dollars for USD
amount=transaction.amount / 100.0 if transaction.currency == "USD" else transaction.amount,
status=transaction.status,
created_at=transaction.created_at,
updated_at=transaction.updated_at,
)
for transaction in transactions],
total=transactions[0].total if transactions else 0,
page_size=page_size,
page_index=page_index,
)

@router.get("/transactions/history/clerk", response_model=TransactionHistoryResponse)
async def get_wallet_transactions_history_clerk(
user: User = Depends(get_current_active_user_from_clerk),
db: AsyncSession = Depends(get_async_db),
page_size: int = Query(10, ge=1),
page_index: int = Query(0, ge=0),
status: str = Query(None, min_length=1),
started_at: datetime = Query(None),
):
return await get_wallet_transactions_history(user, db, page_size, page_index, status, started_at)
Loading
Loading