Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions alembic/versions/39bcedfae4fe_add_model_default_pricing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""add model default pricing

Revision ID: 39bcedfae4fe
Revises: b206e9a941e3
Create Date: 2025-08-14 18:31:20.897283

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '39bcedfae4fe'
down_revision = 'b206e9a941e3'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add ``fallback_pricing.model_name`` and seed per-model default prices.

    Steps:
      1. Add the nullable ``model_name`` column to ``fallback_pricing``.
      2. Create the index the ORM model declares for that column
         (``index=True``), so the migrated schema matches the model.
      3. Copy one row per distinct ``model_name`` from ``model_pricing`` into
         ``fallback_pricing`` with ``fallback_type = 'model_default'``.

    The seed statement uses ``DISTINCT ON``, which is PostgreSQL-specific.
    """
    op.add_column('fallback_pricing', sa.Column('model_name', sa.String(), nullable=True))
    # Keep the schema in sync with the model's ``index=True`` declaration.
    # No explicit drop is required on downgrade: PostgreSQL removes indexes
    # on a column when the column itself is dropped.
    op.create_index(
        op.f('ix_fallback_pricing_model_name'),
        'fallback_pricing',
        ['model_name'],
        unique=False,
    )

    connection = op.get_bind()
    # DISTINCT ON keeps the first row per model_name in ORDER BY order:
    # openai first, then anthropic, then any other provider. The trailing
    # ``effective_date desc`` makes the pick deterministic when a provider has
    # several temporal rows for the same model, preferring the newest one.
    connection.execute(sa.text("""
        insert into fallback_pricing (provider_name, model_name, input_token_price, output_token_price, cached_token_price, currency, created_at, updated_at, effective_date, end_date, fallback_type, description)
        select distinct on (model_name)
            provider_name as provider_name,
            model_name as model_name,
            input_token_price,
            output_token_price,
            cached_token_price,
            currency,
            created_at,
            updated_at,
            effective_date,
            end_date,
            'model_default' as fallback_type,
            null as description
        from model_pricing
        order by model_name,
            case when provider_name = 'openai' then 1
                 when provider_name = 'anthropic' then 2
                 else 3 end,
            effective_date desc
    """))


def downgrade() -> None:
    """Revert the migration.

    Deletes every ``fallback_pricing`` row whose ``fallback_type`` is
    ``'model_default'`` and then drops the ``model_name`` column.
    """
    # Rows are removed before the column so no model-default entries are left
    # behind without their model_name.
    purge_model_defaults = sa.text("""
delete from fallback_pricing where fallback_type = 'model_default'
""")
    op.get_bind().execute(purge_model_defaults)
    op.drop_column('fallback_pricing', 'model_name')
3 changes: 2 additions & 1 deletion app/models/pricing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class FallbackPricing(BaseModel):
__tablename__ = "fallback_pricing"

provider_name = Column(String, nullable=True, index=True) # NULL for global fallback
fallback_type = Column(String(20), nullable=False, index=True) # 'provider_default', 'global_default'
model_name = Column(String, nullable=True, index=True) # NULL for global fallback
fallback_type = Column(String(20), nullable=False, index=True) # 'model_default', 'provider_default', 'global_default'

# Temporal fields
effective_date = Column(DateTime(timezone=True), nullable=False, default=datetime.datetime.now(UTC))
Expand Down
131 changes: 114 additions & 17 deletions app/services/pricing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from app.core.async_cache import async_provider_service_cache
from app.models.pricing import ModelPricing, FallbackPricing

from app.utils.model_name_matcher import ModelNameMatcher, ModelMatch

logger = get_logger(name="pricing_service")

class PricingService:
Expand Down Expand Up @@ -88,23 +90,31 @@ async def _fetch_pricing_with_smart_caching(
Fetch pricing with smart caching at multiple levels
"""

# Level 1: Try exact model cache (hot cache)
# Level 1: Try exact model cache
exact_cache_key = f"pricing:exact:{provider_name}:{model_name}"
exact_pricing = await async_provider_service_cache.get(exact_cache_key)

if exact_pricing and PricingService._is_pricing_valid_for_date(exact_pricing, calculation_date):
logger.debug(f"Exact match cache hit for {provider_name}/{model_name}")
return {**exact_pricing, 'source': 'exact_match'}

# Level 2: Try provider fallback cache (warm cache)
# Level 2: Try model fallback cache
model_cache_key = f"pricing:model_fallback:{provider_name}:{model_name}"
model_fallback = await async_provider_service_cache.get(model_cache_key)

if model_fallback and PricingService._is_pricing_valid_for_date(model_fallback, calculation_date):
logger.debug(f"Model fallback cache hit for {provider_name}/{model_name}")
return {**model_fallback, 'source': 'fallback_model'}

# Level 3: Try provider fallback cache
provider_cache_key = f"pricing:provider_fallback:{provider_name}:{model_name}"
provider_fallback = await async_provider_service_cache.get(provider_cache_key)

if provider_fallback and PricingService._is_pricing_valid_for_date(provider_fallback, calculation_date):
logger.debug(f"Provider fallback cache hit for {provider_name}")
return {**provider_fallback, 'source': 'fallback_provider'}

# Level 3: Try global fallback cache (warm cache)
# Level 4: Try global fallback cache
global_cache_key = f"pricing:global_fallback:{provider_name}:{model_name}"
global_fallback = await async_provider_service_cache.get(global_cache_key)

Expand Down Expand Up @@ -142,6 +152,20 @@ async def _fetch_from_database_with_caching(
)
return {**exact_pricing, 'source': 'exact_match'}

# Try model fallback
model_fallback = await PricingService._get_model_fallback_pricing_db(
db, model_name, calculation_date
)

if model_fallback:
# Cache model fallback (warm cache)
cache_key = f"pricing:model_fallback:{provider_name}:{model_name}"
await async_provider_service_cache.set(
cache_key, model_fallback, ttl=PricingService.FALLBACK_CACHE_TTL
)
logger.warning(f"Using model fallback pricing for {provider_name}/{model_name}")
return {**model_fallback, 'source': 'fallback_model'}

# Try provider fallback
provider_fallback = await PricingService._get_provider_fallback_pricing_db(
db, provider_name, calculation_date
Expand Down Expand Up @@ -368,6 +392,7 @@ def _get_cache_ttl(source: str) -> int:
"""Get appropriate TTL based on pricing data source"""
ttl_map = {
'exact_match': PricingService.EXACT_CACHE_TTL,
'fallback_model': PricingService.FALLBACK_CACHE_TTL,
'fallback_provider': PricingService.FALLBACK_CACHE_TTL,
'fallback_global': PricingService.FALLBACK_CACHE_TTL,
'emergency_fallback': PricingService.EMERGENCY_CACHE_TTL,
Expand Down Expand Up @@ -416,27 +441,48 @@ def _calculate_costs_from_pricing(
# Database query methods (only called on cache misses)
@staticmethod
async def _get_exact_model_pricing_db(db: AsyncSession, provider_name: str, model_name: str, calculation_date: datetime) -> Optional[Dict[str, Any]]:
"""Get model pricing from database using longest prefix matching with pure SQL"""
"""Get model pricing from database"""
# First, get available models for this provider, sorted by effective_date DESC
# This ensures we prioritize the most recent models
query = select(ModelPricing.model_name, func.max(ModelPricing.effective_date).label('latest_date')).where(
ModelPricing.provider_name == provider_name,
ModelPricing.effective_date <= calculation_date,
or_(ModelPricing.end_date.is_(None), ModelPricing.end_date > calculation_date)
).group_by(ModelPricing.model_name).order_by(
func.max(ModelPricing.effective_date).desc() # Most recent models first
)

query = select(ModelPricing).where(
result = await db.execute(query)
model_rows = result.fetchall()

if not model_rows:
return None

# Now, find the best match for the input model name
available_models = [row[0] for row in model_rows]
matcher = ModelNameMatcher(available_models)
best_match = matcher.find_best_match(model_name)

if not best_match:
return None

if best_match.match_type != 'exact':
logger.info(f"Provider {provider_name} model match: '{model_name}' -> '{best_match.matched_model}'(type: {best_match.match_type}, confidence: {best_match.confidence:.2f})")


pricing_query = select(ModelPricing).where(
ModelPricing.provider_name == provider_name,
ModelPricing.model_name == best_match.matched_model,
ModelPricing.effective_date <= calculation_date,
or_(ModelPricing.end_date.is_(None), ModelPricing.end_date > calculation_date),
# The input model starts with the stored model name (prefix match)
text(f"'{model_name}' ilike concat(model_name, '%%')")
or_(ModelPricing.end_date.is_(None), ModelPricing.end_date > calculation_date)
).order_by(
# Longest prefix first
func.length(ModelPricing.model_name).desc(),
ModelPricing.effective_date.desc()
).limit(1)
result = await db.execute(query)

result = await db.execute(pricing_query)
pricing = result.scalar_one_or_none()

if pricing:
if pricing.model_name != model_name:
logger.debug(f"Prefix match: '{model_name}' matched with '{pricing.model_name}'")

return {
'input_price': pricing.input_token_price,
'output_price': pricing.output_token_price,
Expand All @@ -448,6 +494,57 @@ async def _get_exact_model_pricing_db(db: AsyncSession, provider_name: str, mode

return None

@staticmethod
async def _get_model_fallback_pricing_db(db: AsyncSession, model_name: str, calculation_date: datetime) -> Optional[Dict[str, Any]]:
    """Get model fallback pricing from database.

    Looks up ``fallback_pricing`` rows of type ``'model_default'`` that are
    active on ``calculation_date``, matches ``model_name`` against their
    model names with ``ModelNameMatcher``, and returns the most recently
    effective row for the matched name.

    Args:
        db: Async database session used to run the queries.
        model_name: Requested model name; it does not have to equal a stored
            name exactly, the matcher selects the best candidate.
        calculation_date: Point in time the pricing must be valid for
            (``effective_date <= calculation_date`` and either no
            ``end_date`` or ``end_date > calculation_date``).

    Returns:
        A dict with ``input_price``, ``output_price``, ``cached_price``,
        ``currency``, ``effective_date`` (ISO string) and ``end_date`` (ISO
        string or ``None``), or ``None`` when no usable row or match exists.
    """
    # Both queries share these conditions so the pricing row returned always
    # comes from the same set of rows that produced the candidate names.
    active_model_defaults = (
        FallbackPricing.fallback_type == 'model_default',
        FallbackPricing.effective_date <= calculation_date,
        or_(FallbackPricing.end_date.is_(None), FallbackPricing.end_date > calculation_date),
    )

    query = select(
        FallbackPricing.model_name,
        func.max(FallbackPricing.effective_date).label('latest_date'),
    ).where(
        *active_model_defaults
    ).group_by(FallbackPricing.model_name).order_by(
        func.max(FallbackPricing.effective_date).desc()  # Most recent models first
    )

    result = await db.execute(query)
    model_rows = result.fetchall()

    # model_name is a nullable column; a NULL name can never be a valid match
    # candidate, so it is dropped before matching.
    available_models = [row[0] for row in model_rows if row[0] is not None]
    if not available_models:
        return None

    # Now, find the best match for the input model name
    matcher = ModelNameMatcher(available_models)
    best_match = matcher.find_best_match(model_name)

    if not best_match:
        return None

    if best_match.match_type != 'exact':
        logger.info(f"Fallback model match: '{model_name}' -> '{best_match.matched_model}'(type: {best_match.match_type}, confidence: {best_match.confidence:.2f})")

    pricing_query = select(FallbackPricing).where(
        FallbackPricing.model_name == best_match.matched_model,
        *active_model_defaults
    ).order_by(
        FallbackPricing.effective_date.desc()
    ).limit(1)

    result = await db.execute(pricing_query)
    fallback = result.scalar_one_or_none()

    if fallback is None:
        return None

    return {
        'input_price': fallback.input_token_price,
        'output_price': fallback.output_token_price,
        'cached_price': fallback.cached_token_price,
        'currency': fallback.currency,
        'effective_date': fallback.effective_date.isoformat(),
        'end_date': fallback.end_date.isoformat() if fallback.end_date else None,
    }

@staticmethod
async def _get_provider_fallback_pricing_db(db: AsyncSession, provider_name: str, calculation_date: datetime) -> Optional[Dict[str, Any]]:
"""Get provider fallback pricing from database"""
Expand All @@ -466,7 +563,7 @@ async def _get_provider_fallback_pricing_db(db: AsyncSession, provider_name: str
'input_price': fallback.input_token_price,
'output_price': fallback.output_token_price,
'cached_price': fallback.cached_token_price,
'currency': 'USD',
'currency': fallback.currency,
'effective_date': fallback.effective_date.isoformat(),
'end_date': fallback.end_date.isoformat() if fallback.end_date else None,
}
Expand All @@ -489,7 +586,7 @@ async def _get_global_fallback_pricing_db(db: AsyncSession, calculation_date: da
'input_price': fallback.input_token_price,
'output_price': fallback.output_token_price,
'cached_price': fallback.cached_token_price,
'currency': 'USD',
'currency': fallback.currency,
'effective_date': fallback.effective_date.isoformat(),
'end_date': fallback.end_date.isoformat() if fallback.end_date else None,
}
Expand Down
Loading
Loading