Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ tests/performance/results/
.coverage*

# Local test script
tools/local_test_script.py
tools/tmp/*
182 changes: 182 additions & 0 deletions alembic/versions/1876c1c4bc96_add_price_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
"""add price table

Revision ID: 1876c1c4bc96
Revises: 831fc2cf16ee
Create Date: 2025-08-11 17:57:04.438535

"""
from alembic import op
import sqlalchemy as sa
from csv import DictReader
from datetime import datetime, UTC, timedelta
import os
import decimal

# revision identifiers, used by Alembic.
revision = '1876c1c4bc96'
down_revision = '831fc2cf16ee'
branch_labels = None
depends_on = None

def upgrade() -> None:
# Create model_pricing table
op.create_table(
'model_pricing',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('CURRENT_TIMESTAMP')),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('CURRENT_TIMESTAMP')),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('provider_name', sa.String(), nullable=False),
sa.Column('model_name', sa.String(), nullable=False),
sa.Column('effective_date', sa.DateTime(timezone=True), nullable=False),
sa.Column('end_date', sa.DateTime(timezone=True), nullable=True),
sa.Column('input_token_price', sa.DECIMAL(precision=12, scale=8), nullable=False),
sa.Column('output_token_price', sa.DECIMAL(precision=12, scale=8), nullable=False),
sa.Column('cached_token_price', sa.DECIMAL(precision=12, scale=8), nullable=False, server_default=sa.text('0')),
sa.Column('currency', sa.String(length=3), nullable=False, server_default='USD'),
sa.Column('price_source', sa.String(length=50), nullable=False, server_default='manual'),
sa.PrimaryKeyConstraint('id')
)

# Create indexes for model_pricing
op.create_index('ix_model_pricing_active', 'model_pricing',
['provider_name', 'model_name', 'effective_date', 'end_date'])
op.create_index('ix_model_pricing_temporal', 'model_pricing',
['effective_date', 'end_date'])
op.create_index('ix_model_pricing_unique_period', 'model_pricing',
['provider_name', 'model_name', 'effective_date'], unique=True)
op.create_index(op.f('ix_model_pricing_provider_name'), 'model_pricing', ['provider_name'])
op.create_index(op.f('ix_model_pricing_model_name'), 'model_pricing', ['model_name'])
op.create_index(op.f('ix_model_pricing_id'), 'model_pricing', ['id'])

# Create fallback_pricing table
op.create_table(
'fallback_pricing',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('CURRENT_TIMESTAMP')),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False, server_default=sa.text('CURRENT_TIMESTAMP')),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('provider_name', sa.String(), nullable=True),
sa.Column('fallback_type', sa.String(length=20), nullable=False),
sa.Column('effective_date', sa.DateTime(timezone=True), nullable=False),
sa.Column('end_date', sa.DateTime(timezone=True), nullable=True),
sa.Column('input_token_price', sa.DECIMAL(precision=12, scale=8), nullable=False),
sa.Column('output_token_price', sa.DECIMAL(precision=12, scale=8), nullable=False),
sa.Column('cached_token_price', sa.DECIMAL(precision=12, scale=8), nullable=False, server_default=sa.text('0')),
sa.Column('currency', sa.String(length=3), nullable=False, server_default='USD'),
sa.Column('description', sa.String(length=255), nullable=True),
sa.PrimaryKeyConstraint('id')
)

# Create indexes for fallback_pricing
op.create_index('ix_fallback_pricing_active', 'fallback_pricing',
['provider_name', 'fallback_type', 'effective_date', 'end_date'])
op.create_index('ix_fallback_pricing_type', 'fallback_pricing',
['fallback_type', 'effective_date'])
op.create_index(op.f('ix_fallback_pricing_provider_name'), 'fallback_pricing', ['provider_name'])
op.create_index(op.f('ix_fallback_pricing_fallback_type'), 'fallback_pricing', ['fallback_type'])
op.create_index(op.f('ix_fallback_pricing_id'), 'fallback_pricing', ['id'])

# Insert model pricing data
effective_date = datetime.now(UTC) - timedelta(days=30)
csv_path = os.path.join(os.path.dirname(__file__), "..", "..", "tools", "data", "model_pricing_init.csv")
with open(csv_path, "r") as f:
reader = DictReader(f)
rows_to_insert = []
for row in reader:
rows_to_insert.append({
"provider_name": row["provider_name"],
"model_name": row["model_name"],
"effective_date": effective_date,
"input_token_price": (decimal.Decimal(str(row["input_token_price"])) * 1000).normalize(),
"output_token_price": (decimal.Decimal(str(row["output_token_price"])) * 1000).normalize(),
"price_source": "manual"
})
if rows_to_insert:
connection = op.get_bind()
connection.execute(
sa.text("""
INSERT INTO model_pricing (provider_name, model_name, effective_date, input_token_price, output_token_price, price_source)
VALUES (:provider_name, :model_name, :effective_date, :input_token_price, :output_token_price, 'manual')
"""),
rows_to_insert,
)

# Insert some initial fallback pricing data
# For all the providers in the model_pricing table, insert a fallback pricing record with the provider_default fallback_type, set the prcie to be the average of the input_token_price and output_token_price
# For global fallback, set the provider_name to NULL, and the fallback_type to global_default, and the price to be the average of the input_token_price and output_token_price of all the providers in the model_pricing table
# The effective_date should be the same as the effective_date of the model_pricing table

# Get all unique providers from model_pricing table
providers_result = connection.execute(
sa.text("SELECT DISTINCT provider_name FROM model_pricing")
).fetchall()

fallback_rows = []

# Insert provider-specific fallback pricing
for provider_row in providers_result:
provider_name = provider_row[0]

# Calculate average prices for this provider
avg_prices_result = connection.execute(
sa.text("""
SELECT
AVG(input_token_price) as avg_input_price,
AVG(output_token_price) as avg_output_price
FROM model_pricing
WHERE provider_name = :provider_name
"""),
{"provider_name": provider_name}
).fetchone()

avg_input_price = avg_prices_result[0]
avg_output_price = avg_prices_result[1]

fallback_rows.append({
"provider_name": provider_name,
"fallback_type": "provider_default",
"effective_date": effective_date,
"input_token_price": avg_input_price,
"output_token_price": avg_output_price,
"description": f"Default pricing for {provider_name} provider"
})

# Calculate global average prices
global_avg_result = connection.execute(
sa.text("""
SELECT
AVG(input_token_price) as avg_input_price,
AVG(output_token_price) as avg_output_price
FROM model_pricing
""")
).fetchone()

global_avg_input_price = global_avg_result[0]
global_avg_output_price = global_avg_result[1]

# Insert global fallback pricing
fallback_rows.append({
"provider_name": None,
"fallback_type": "global_default",
"effective_date": effective_date,
"input_token_price": global_avg_input_price,
"output_token_price": global_avg_output_price,
"description": "Global default pricing for all providers"
})

# Insert fallback pricing data
if fallback_rows:
connection.execute(
sa.text("""
INSERT INTO fallback_pricing (provider_name, fallback_type, effective_date, input_token_price, output_token_price, description)
VALUES (:provider_name, :fallback_type, :effective_date, :input_token_price, :output_token_price, :description)
"""),
fallback_rows,
)


def downgrade() -> None:
# Drop tables in reverse order
op.drop_table('fallback_pricing')
op.drop_table('model_pricing')
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""add cost tracking to usage_tracker table

Revision ID: b206e9a941e3
Revises: 1876c1c4bc96
Create Date: 2025-08-11 18:19:08.581296

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'b206e9a941e3'
down_revision = '1876c1c4bc96'
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column('usage_tracker', sa.Column('cost', sa.DECIMAL(precision=12, scale=8), nullable=True))
op.add_column('usage_tracker', sa.Column('currency', sa.String(length=3), nullable=True))
op.add_column('usage_tracker', sa.Column('pricing_source', sa.String(length=255), nullable=True))


def downgrade() -> None:
op.drop_column('usage_tracker', 'cost')
op.drop_column('usage_tracker', 'currency')
op.drop_column('usage_tracker', 'pricing_source')
67 changes: 67 additions & 0 deletions app/models/pricing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# app/models/pricing.py
import datetime
from datetime import UTC
from sqlalchemy import Column, DateTime, String, DECIMAL, Index

from .base import BaseModel


class ModelPricing(BaseModel):
"""
Store pricing information for specific models with temporal support
"""
__tablename__ = "model_pricing"

provider_name = Column(String, nullable=False, index=True)
model_name = Column(String, nullable=False, index=True)

# Temporal fields for price changes over time
effective_date = Column(DateTime(timezone=True), nullable=False, default=datetime.datetime.now(UTC))
end_date = Column(DateTime(timezone=True), nullable=True) # NULL means currently active

# Pricing per 1K tokens (using DECIMAL for precision)
input_token_price = Column(DECIMAL(12, 8), nullable=False) # Price per 1K input tokens
output_token_price = Column(DECIMAL(12, 8), nullable=False) # Price per 1K output tokens
cached_token_price = Column(DECIMAL(12, 8), nullable=False, default=0) # Price per 1K cached tokens

# Metadata
currency = Column(String(3), nullable=False, default='USD')

# Indexes for efficient querying
__table_args__ = (
# Index for finding active pricing for a model
Index('ix_model_pricing_active', 'provider_name', 'model_name', 'effective_date', 'end_date'),
# Index for temporal queries
Index('ix_model_pricing_temporal', 'effective_date', 'end_date'),
# Unique constraint for overlapping periods (business rule enforcement)
Index('ix_model_pricing_unique_period', 'provider_name', 'model_name', 'effective_date', unique=True),
)


class FallbackPricing(BaseModel):
"""
Store fallback pricing for providers and global defaults
"""
__tablename__ = "fallback_pricing"

provider_name = Column(String, nullable=True, index=True) # NULL for global fallback
fallback_type = Column(String(20), nullable=False, index=True) # 'provider_default', 'global_default'

# Temporal fields
effective_date = Column(DateTime(timezone=True), nullable=False, default=datetime.datetime.now(UTC))
end_date = Column(DateTime(timezone=True), nullable=True)

# Pricing per 1K tokens
input_token_price = Column(DECIMAL(12, 8), nullable=False)
output_token_price = Column(DECIMAL(12, 8), nullable=False)
cached_token_price = Column(DECIMAL(12, 8), nullable=False, default=0)

# Metadata
currency = Column(String(3), nullable=False, default='USD')
description = Column(String(255), nullable=True) # Optional description

# Indexes
__table_args__ = (
Index('ix_fallback_pricing_active', 'provider_name', 'fallback_type', 'effective_date', 'end_date'),
Index('ix_fallback_pricing_type', 'fallback_type', 'effective_date'),
)
1 change: 1 addition & 0 deletions app/models/provider_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ class ProviderKey(BaseModel):
back_populates="allowed_provider_keys",
lazy="selectin",
)
usage_tracker = relationship("UsageTracker", back_populates="provider_key")
8 changes: 7 additions & 1 deletion app/models/usage_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from datetime import UTC
import uuid

from sqlalchemy import Column, DateTime, ForeignKey, Integer, String
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, DECIMAL
from sqlalchemy.orm import relationship
from sqlalchemy.dialects.postgresql import UUID
from .base import Base

Expand All @@ -21,3 +22,8 @@ class UsageTracker(Base):
output_tokens = Column(Integer, nullable=True)
cached_tokens = Column(Integer, nullable=True)
reasoning_tokens = Column(Integer, nullable=True)
cost = Column(DECIMAL(12, 8), nullable=True)
currency = Column(String(3), nullable=True)
pricing_source = Column(String(255), nullable=True)

provider_key = relationship("ProviderKey", back_populates="usage_tracker")
Loading
Loading