"""add model default pricing

Revision ID: 39bcedfae4fe
Revises: b206e9a941e3
Create Date: 2025-08-14 18:31:20.897283

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '39bcedfae4fe'
down_revision = 'b206e9a941e3'
branch_labels = None
depends_on = None

# The ORM model declares ``model_name`` with ``index=True``; this is the name
# SQLAlchemy derives for that index by default (``ix_<table>_<column>``).
# NOTE(review): assumes the project's MetaData has no custom naming
# convention -- confirm before applying.
_MODEL_NAME_INDEX = 'ix_fallback_pricing_model_name'


def upgrade() -> None:
    """Add ``fallback_pricing.model_name`` and seed 'model_default' rows.

    One fallback row is created per distinct ``model_pricing.model_name``.
    When several providers price the same model, the provider preference is
    openai, then anthropic, then anything else; within the chosen provider
    the most recent ``effective_date`` wins so the copied row is
    deterministic.
    """
    op.add_column('fallback_pricing', sa.Column('model_name', sa.String(), nullable=True))
    # Keep the database schema in sync with the model's ``index=True``.
    op.create_index(_MODEL_NAME_INDEX, 'fallback_pricing', ['model_name'], unique=False)

    # ``op.execute`` (rather than ``op.get_bind().execute``) also works when
    # the migration is rendered in offline ``--sql`` mode.
    # ``DISTINCT ON`` is PostgreSQL-specific syntax.
    op.execute(sa.text("""
        insert into fallback_pricing (provider_name, model_name, input_token_price, output_token_price, cached_token_price, currency, created_at, updated_at, effective_date, end_date, fallback_type, description)
        select distinct on (model_name)
            provider_name as provider_name,
            model_name as model_name,
            input_token_price,
            output_token_price,
            cached_token_price,
            currency,
            created_at,
            updated_at,
            effective_date,
            end_date,
            'model_default' as fallback_type,
            null as description
        from model_pricing
        order by model_name,
            case when provider_name = 'openai' then 1
                 when provider_name = 'anthropic' then 2
                 else 3 end,
            effective_date desc
    """))


def downgrade() -> None:
    """Remove the seeded 'model_default' rows, the index and the column."""
    op.execute(sa.text("""
        delete from fallback_pricing where fallback_type = 'model_default'
    """))
    op.drop_index(_MODEL_NAME_INDEX, table_name='fallback_pricing')
    op.drop_column('fallback_pricing', 'model_name')
nullable=False, index=True) # 'provider_default', 'global_default' + model_name = Column(String, nullable=True, index=True) # NULL for global fallback + fallback_type = Column(String(20), nullable=False, index=True) # 'model_default', 'provider_default', 'global_default' # Temporal fields effective_date = Column(DateTime(timezone=True), nullable=False, default=datetime.datetime.now(UTC)) diff --git a/app/services/pricing_service.py b/app/services/pricing_service.py index 66bf611..7c840fe 100644 --- a/app/services/pricing_service.py +++ b/app/services/pricing_service.py @@ -9,6 +9,8 @@ from app.core.async_cache import async_provider_service_cache from app.models.pricing import ModelPricing, FallbackPricing +from app.utils.model_name_matcher import ModelNameMatcher, ModelMatch + logger = get_logger(name="pricing_service") class PricingService: @@ -88,7 +90,7 @@ async def _fetch_pricing_with_smart_caching( Fetch pricing with smart caching at multiple levels """ - # Level 1: Try exact model cache (hot cache) + # Level 1: Try exact model cache exact_cache_key = f"pricing:exact:{provider_name}:{model_name}" exact_pricing = await async_provider_service_cache.get(exact_cache_key) @@ -96,7 +98,15 @@ async def _fetch_pricing_with_smart_caching( logger.debug(f"Exact match cache hit for {provider_name}/{model_name}") return {**exact_pricing, 'source': 'exact_match'} - # Level 2: Try provider fallback cache (warm cache) + # Level 2: Try model fallback cache + model_cache_key = f"pricing:model_fallback:{provider_name}:{model_name}" + model_fallback = await async_provider_service_cache.get(model_cache_key) + + if model_fallback and PricingService._is_pricing_valid_for_date(model_fallback, calculation_date): + logger.debug(f"Model fallback cache hit for {provider_name}/{model_name}") + return {**model_fallback, 'source': 'fallback_model'} + + # Level 3: Try provider fallback cache provider_cache_key = f"pricing:provider_fallback:{provider_name}:{model_name}" provider_fallback = 
await async_provider_service_cache.get(provider_cache_key) @@ -104,7 +114,7 @@ async def _fetch_pricing_with_smart_caching( logger.debug(f"Provider fallback cache hit for {provider_name}") return {**provider_fallback, 'source': 'fallback_provider'} - # Level 3: Try global fallback cache (warm cache) + # Level 4: Try global fallback cache global_cache_key = f"pricing:global_fallback:{provider_name}:{model_name}" global_fallback = await async_provider_service_cache.get(global_cache_key) @@ -142,6 +152,20 @@ async def _fetch_from_database_with_caching( ) return {**exact_pricing, 'source': 'exact_match'} + # Try model fallback + model_fallback = await PricingService._get_model_fallback_pricing_db( + db, model_name, calculation_date + ) + + if model_fallback: + # Cache model fallback (warm cache) + cache_key = f"pricing:model_fallback:{provider_name}:{model_name}" + await async_provider_service_cache.set( + cache_key, model_fallback, ttl=PricingService.FALLBACK_CACHE_TTL + ) + logger.warning(f"Using model fallback pricing for {provider_name}/{model_name}") + return {**model_fallback, 'source': 'fallback_model'} + # Try provider fallback provider_fallback = await PricingService._get_provider_fallback_pricing_db( db, provider_name, calculation_date @@ -368,6 +392,7 @@ def _get_cache_ttl(source: str) -> int: """Get appropriate TTL based on pricing data source""" ttl_map = { 'exact_match': PricingService.EXACT_CACHE_TTL, + 'fallback_model': PricingService.FALLBACK_CACHE_TTL, 'fallback_provider': PricingService.FALLBACK_CACHE_TTL, 'fallback_global': PricingService.FALLBACK_CACHE_TTL, 'emergency_fallback': PricingService.EMERGENCY_CACHE_TTL, @@ -416,27 +441,48 @@ def _calculate_costs_from_pricing( # Database query methods (only called on cache misses) @staticmethod async def _get_exact_model_pricing_db(db: AsyncSession, provider_name: str, model_name: str, calculation_date: datetime) -> Optional[Dict[str, Any]]: - """Get model pricing from database using longest prefix 
matching with pure SQL""" + """Get model pricing from database""" + # First, get available models for this provider, sorted by effective_date DESC + # This ensures we prioritize the most recent models + query = select(ModelPricing.model_name, func.max(ModelPricing.effective_date).label('latest_date')).where( + ModelPricing.provider_name == provider_name, + ModelPricing.effective_date <= calculation_date, + or_(ModelPricing.end_date.is_(None), ModelPricing.end_date > calculation_date) + ).group_by(ModelPricing.model_name).order_by( + func.max(ModelPricing.effective_date).desc() # Most recent models first + ) - query = select(ModelPricing).where( + result = await db.execute(query) + model_rows = result.fetchall() + + if not model_rows: + return None + + # Now, find the best match for the input model name + available_models = [row[0] for row in model_rows] + matcher = ModelNameMatcher(available_models) + best_match = matcher.find_best_match(model_name) + + if not best_match: + return None + + if best_match.match_type != 'exact': + logger.info(f"Provider {provider_name} model match: '{model_name}' -> '{best_match.matched_model}'(type: {best_match.match_type}, confidence: {best_match.confidence:.2f})") + + + pricing_query = select(ModelPricing).where( ModelPricing.provider_name == provider_name, + ModelPricing.model_name == best_match.matched_model, ModelPricing.effective_date <= calculation_date, - or_(ModelPricing.end_date.is_(None), ModelPricing.end_date > calculation_date), - # The input model starts with the stored model name (prefix match) - text(f"'{model_name}' ilike concat(model_name, '%%')") + or_(ModelPricing.end_date.is_(None), ModelPricing.end_date > calculation_date) ).order_by( - # Longest prefix first - func.length(ModelPricing.model_name).desc(), ModelPricing.effective_date.desc() ).limit(1) - - result = await db.execute(query) + + result = await db.execute(pricing_query) pricing = result.scalar_one_or_none() if pricing: - if pricing.model_name != 
model_name: - logger.debug(f"Prefix match: '{model_name}' matched with '{pricing.model_name}'") - return { 'input_price': pricing.input_token_price, 'output_price': pricing.output_token_price, @@ -448,6 +494,57 @@ async def _get_exact_model_pricing_db(db: AsyncSession, provider_name: str, mode return None + @staticmethod + async def _get_model_fallback_pricing_db(db: AsyncSession, model_name: str, calculation_date: datetime) -> Optional[Dict[str, Any]]: + """Get model fallback pricing from database""" + query = select(FallbackPricing.model_name, func.max(FallbackPricing.effective_date).label('latest_date')).where( + FallbackPricing.effective_date <= calculation_date, + FallbackPricing.fallback_type == 'model_default', + or_(FallbackPricing.end_date.is_(None), FallbackPricing.end_date > calculation_date) + ).group_by(FallbackPricing.model_name).order_by( + func.max(FallbackPricing.effective_date).desc() # Most recent models first + ) + + result = await db.execute(query) + model_rows = result.fetchall() + + if not model_rows: + return None + + # Now, find the best match for the input model name + available_models = [row[0] for row in model_rows] + matcher = ModelNameMatcher(available_models) + best_match = matcher.find_best_match(model_name) + + if not best_match: + return None + + if best_match.match_type != 'exact': + logger.info(f"Fallback model match: '{model_name}' -> '{best_match.matched_model}'(type: {best_match.match_type}, confidence: {best_match.confidence:.2f})") + + pricing_query = select(FallbackPricing).where( + FallbackPricing.model_name == best_match.matched_model, + FallbackPricing.effective_date <= calculation_date, + or_(FallbackPricing.end_date.is_(None), FallbackPricing.end_date > calculation_date) + ).order_by( + FallbackPricing.effective_date.desc() + ).limit(1) + + result = await db.execute(pricing_query) + fallback = result.scalar_one_or_none() + + if fallback: + return { + 'input_price': fallback.input_token_price, + 'output_price': 
import re
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from difflib import SequenceMatcher

# Patterns are compiled once at import time; they are used on every lookup.
_SEPARATOR_RUN = re.compile(r'[_\s]+')
_HYPHEN_RUN = re.compile(r'-+')
_VERSION_BOUNDARY = re.compile(r'[-_]\d')
_FAMILY_SUFFIX = re.compile(r'-(latest|preview)$')
_DATE_COMPACT = re.compile(r'^\d{8}$')
_DATE_DASHED = re.compile(r'^\d{4}-\d{2}-\d{2}$')
_DATE_SLASHED = re.compile(r'^\d{4}/\d{2}/\d{2}$')


@dataclass
class ModelMatch:
    """Represents a model match with confidence score"""
    matched_model: str      # name exactly as it appears in the available list
    confidence: float       # 1.0 for exact matches, lower for weaker strategies
    match_type: str         # 'exact', 'normalized', 'prefix', 'fuzzy'
    normalized_query: str   # query after normalization (raw query for 'exact')
    normalized_match: str   # matched name after normalization


class ModelNameMatcher:
    """
    Match a requested model name against a list of known model names.

    Strategies, tried in order by ``find_best_match``:
    - exact string match
    - match after normalization (case, separators, ``@`` date suffixes)
    - prefix match in either direction
    - fuzzy match (difflib similarity), first inside the model family

    ``available_models`` is expected in priority order: when two names
    normalize to the same string, the one listed first is kept.
    """

    def __init__(self, available_models: List[str]):
        # ``None`` entries (e.g. NULL database columns) cannot be matched.
        self.available_models = [m for m in available_models if m is not None]

        # normalized name -> original name. ``setdefault`` keeps the FIRST
        # model on a collision so the caller's priority order is honoured.
        self.normalized_models: Dict[str, str] = {}
        for model in self.available_models:
            normalized = self._normalize_model_name(model)
            if normalized:  # an empty key would prefix-match every query
                self.normalized_models.setdefault(normalized, model)

        self.model_families = self._build_model_families()

    def _normalize_model_name(self, model_name: str) -> str:
        """
        Return the canonical form of ``model_name`` used for comparisons.

        - lower-cased and stripped
        - ``base@date`` becomes ``base-date`` with the date as YYYY-MM-DD
        - underscores / whitespace become hyphens, hyphen runs collapse
        """
        normalized = model_name.lower().strip()

        # gpt-4.1-mini@20250414 -> gpt-4.1-mini-2025-04-14
        if '@' in normalized:
            parts = normalized.split('@')
            if len(parts) == 2:
                base_model, date_part = parts
                date_part = self._normalize_date_format(date_part)
                normalized = f"{base_model}-{date_part}"

        normalized = _SEPARATOR_RUN.sub('-', normalized)
        normalized = _HYPHEN_RUN.sub('-', normalized)

        return normalized

    def _normalize_date_format(self, date_str: str) -> str:
        """Convert YYYYMMDD or YYYY/MM/DD to YYYY-MM-DD; other input is returned stripped."""
        date_str = date_str.strip()

        if _DATE_COMPACT.match(date_str):
            return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"

        if _DATE_DASHED.match(date_str):
            return date_str

        if _DATE_SLASHED.match(date_str):
            return date_str.replace('/', '-')

        return date_str

    @staticmethod
    def _family_base(normalized_name: str) -> str:
        """Return the family key of an already-normalized model name.

        The key is the text before the first separator-plus-digit, with a
        trailing ``-latest`` / ``-preview`` removed. Both the stored models
        and the query go through this one function so their keys agree.
        """
        base = _VERSION_BOUNDARY.split(normalized_name)[0]
        return _FAMILY_SUFFIX.sub('', base)

    def _build_model_families(self) -> Dict[str, List[str]]:
        """Group the original model names by the family key of their normalized form."""
        families: Dict[str, List[str]] = {}

        for model in self.available_models:
            normalized = self._normalize_model_name(model)
            if not normalized:
                continue
            families.setdefault(self._family_base(normalized), []).append(model)

        return families

    def find_best_match(self, query_model: str, min_confidence: float = 0.6) -> Optional[ModelMatch]:
        """
        Return the single best match for ``query_model``, or ``None``.

        Strategies are tried in order and the first hit wins:
        1. Exact match (confidence 1.0)
        2. Normalized exact match (confidence 0.95)
        3. Prefix match (longest matched span, then highest confidence)
        4. Fuzzy match, family first

        ``min_confidence`` applies to the fuzzy strategy only; prefix matches
        are accepted at any confidence. A query that is ``None`` or
        normalizes to an empty string never matches.
        """
        if query_model is None:
            return None

        # Strategy 1: Exact match
        if query_model in self.available_models:
            return ModelMatch(
                matched_model=query_model,
                confidence=1.0,
                match_type='exact',
                normalized_query=query_model,
                normalized_match=query_model
            )

        normalized_query = self._normalize_model_name(query_model)
        if not normalized_query:
            # "" is a prefix of every name; treat it as "no information".
            return None

        # Strategy 2: Normalized exact match
        if normalized_query in self.normalized_models:
            return ModelMatch(
                matched_model=self.normalized_models[normalized_query],
                confidence=0.95,
                match_type='normalized',
                normalized_query=normalized_query,
                normalized_match=normalized_query
            )

        # Strategy 3: Prefix matching
        prefix_matches = self._find_prefix_matches(normalized_query)
        if prefix_matches:
            query_len = len(normalized_query)
            # Rank by the number of characters that actually matched (the
            # shorter of the two names), then by confidence. Ranking by the
            # stored name's length alone would prefer the least similar
            # candidate whenever the query is a prefix of several models.
            best_prefix = max(
                prefix_matches,
                key=lambda m: (min(len(m[1]), query_len), m[2])
            )
            return ModelMatch(
                matched_model=best_prefix[0],
                confidence=best_prefix[2],
                match_type='prefix',
                normalized_query=normalized_query,
                normalized_match=best_prefix[1]
            )

        # Strategy 4: Fuzzy matching
        return self._find_fuzzy_match(normalized_query, min_confidence)

    def _find_prefix_matches(self, normalized_query: str) -> List[Tuple[str, str, float]]:
        """
        Return ``(original_model, normalized_model, confidence)`` for every
        stored model that is a prefix of the query, or that the query is a
        prefix of. An empty query yields no matches.
        """
        matches: List[Tuple[str, str, float]] = []
        if not normalized_query:
            return matches

        query_len = len(normalized_query)

        for normalized_model, original_model in self.normalized_models.items():
            model_len = len(normalized_model)

            if normalized_query.startswith(normalized_model):
                # Stored model is a prefix of the query.
                confidence = min(0.9, model_len / query_len * 0.9)

                # Bonus when the prefix ends on a word boundary
                if query_len == model_len or normalized_query[model_len] in '-_.':
                    confidence += 0.05

                matches.append((original_model, normalized_model, confidence))

            elif normalized_model.startswith(normalized_query):
                # Query is a prefix of the stored model: scored lower.
                confidence = min(0.85, query_len / model_len * 0.8)

                if model_len == query_len or normalized_model[query_len] in '-_.':
                    confidence += 0.05

                matches.append((original_model, normalized_model, confidence))

        return matches

    def _find_fuzzy_match(self, normalized_query: str, min_confidence: float) -> Optional[ModelMatch]:
        """
        Return the most similar model by ``SequenceMatcher`` ratio, or ``None``.

        Candidates from the query's own family need ``min_confidence``; if the
        family yields nothing, every model is tried with a threshold of at
        least 0.8.
        """
        best_match = None
        best_confidence = 0.0

        query_base = self._family_base(normalized_query)

        for candidate in self.model_families.get(query_base, []):
            normalized_candidate = self._normalize_model_name(candidate)
            confidence = SequenceMatcher(None, normalized_query, normalized_candidate).ratio()

            if confidence > best_confidence and confidence >= min_confidence:
                best_confidence = confidence
                best_match = ModelMatch(
                    matched_model=candidate,
                    confidence=confidence,
                    match_type='fuzzy',
                    normalized_query=normalized_query,
                    normalized_match=normalized_candidate
                )

        if not best_match:
            # Outside the family a match must be much closer to be trusted.
            high_threshold = max(min_confidence, 0.8)
            for normalized_model, original_model in self.normalized_models.items():
                confidence = SequenceMatcher(None, normalized_query, normalized_model).ratio()

                if confidence > best_confidence and confidence >= high_threshold:
                    best_confidence = confidence
                    best_match = ModelMatch(
                        matched_model=original_model,
                        confidence=confidence,
                        match_type='fuzzy',
                        normalized_query=normalized_query,
                        normalized_match=normalized_model
                    )

        return best_match

    def find_all_matches(self, query_model: str, min_confidence: float = 0.6, limit: int = 5) -> List[ModelMatch]:
        """
        Return up to ``limit`` matches with confidence >= ``min_confidence``,
        sorted by confidence (descending).

        An exact match is returned alone. Otherwise the result combines the
        normalized-exact match (if any), prefix matches and, when fewer than
        ``limit`` matches were found, fuzzy matches. Each model appears once.
        """
        matches: List[ModelMatch] = []
        if query_model is None:
            return matches

        # Check exact match first
        if query_model in self.available_models:
            matches.append(ModelMatch(
                matched_model=query_model,
                confidence=1.0,
                match_type='exact',
                normalized_query=query_model,
                normalized_match=query_model
            ))
            return matches

        normalized_query = self._normalize_model_name(query_model)
        if not normalized_query:
            return matches

        seen = set()  # original names already reported

        # Normalized exact match, labelled the same way as find_best_match
        if normalized_query in self.normalized_models and 0.95 >= min_confidence:
            original_model = self.normalized_models[normalized_query]
            seen.add(original_model)
            matches.append(ModelMatch(
                matched_model=original_model,
                confidence=0.95,
                match_type='normalized',
                normalized_query=normalized_query,
                normalized_match=normalized_query
            ))

        # Find all prefix matches
        for original_model, normalized_model, confidence in self._find_prefix_matches(normalized_query):
            if original_model in seen or confidence < min_confidence:
                continue
            seen.add(original_model)
            matches.append(ModelMatch(
                matched_model=original_model,
                confidence=confidence,
                match_type='prefix',
                normalized_query=normalized_query,
                normalized_match=normalized_model
            ))

        # Find fuzzy matches if we don't have enough good matches
        if len(matches) < limit:
            for normalized_model, original_model in self.normalized_models.items():
                if original_model in seen:
                    continue
                confidence = SequenceMatcher(None, normalized_query, normalized_model).ratio()

                if confidence >= min_confidence:
                    seen.add(original_model)
                    matches.append(ModelMatch(
                        matched_model=original_model,
                        confidence=confidence,
                        match_type='fuzzy',
                        normalized_query=normalized_query,
                        normalized_match=normalized_model
                    ))

        # Sort by confidence (descending) and return top matches
        matches.sort(key=lambda x: x.confidence, reverse=True)
        return matches[:limit]