From 929837e8001d4c769f24f7cafa0f6769301aba59 Mon Sep 17 00:00:00 2001
From: Wenjing Yu
Date: Sat, 9 Aug 2025 18:52:11 -0700
Subject: [PATCH] Improve get provider and model name function

---
 app/services/provider_service.py          | 236 +++++++++++++++-----
 tests/unit_tests/test_provider_service.py | 260 +++++++++++++++++-----
 2 files changed, 386 insertions(+), 110 deletions(-)

diff --git a/app/services/provider_service.py b/app/services/provider_service.py
index ca0c256..7df1f37 100644
--- a/app/services/provider_service.py
+++ b/app/services/provider_service.py
@@ -13,7 +13,11 @@
 from app.core.async_cache import async_provider_service_cache, DEBUG_CACHE
 from app.core.logger import get_logger
 from app.core.security import decrypt_api_key, encrypt_api_key
-from app.exceptions.exceptions import InvalidProviderException, BaseInvalidRequestException, InvalidForgeKeyException
+from app.exceptions.exceptions import (
+    InvalidProviderException,
+    BaseInvalidRequestException,
+    InvalidForgeKeyException,
+)
 from app.models.user import User
 from app.core.database import get_db_session
 
@@ -26,9 +30,16 @@
 # Add constants at the top of the file, after imports
 MODEL_PARTS_MIN_LENGTH = 2  # Minimum number of parts in a model name (e.g., "gpt-4")
 
+
 # Create a background task to update the usage tracker that won't be cancelled
 # Even if the streaming response is cancelled by client disconnect
-async def update_usage_in_background(usage_tracker_id: uuid.UUID, input_tokens: int, output_tokens: int, cached_tokens: int, reasoning_tokens: int):
+async def update_usage_in_background(
+    usage_tracker_id: uuid.UUID,
+    input_tokens: int,
+    output_tokens: int,
+    cached_tokens: int,
+    reasoning_tokens: int,
+):
     # Use a fresh DB session for logging, since the original request session
     # may have been closed by FastAPI after the response was returned.
@@ -42,6 +53,7 @@ async def update_usage_in_background(usage_tracker_id: uuid.UUID, input_tokens:
             reasoning_tokens=reasoning_tokens,
         )
 
+
 class ProviderService:
     """Service for handling provider API calls.
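[Editor's note: illustration, not part of the patch] The hunks above detach the
usage write from the request: update_usage_in_background opens a fresh DB session
and is scheduled with asyncio.create_task, so it keeps running after FastAPI has
returned the response and closed the request's own session. A minimal,
self-contained sketch of that pattern; every name in it (record_usage,
_background_tasks, handle_request) is invented for illustration:

    import asyncio

    # Strong references so fire-and-forget tasks are not garbage-collected:
    # the event loop itself only keeps weak references to running tasks.
    _background_tasks: set[asyncio.Task] = set()


    async def record_usage(tracker_id: int, tokens: int) -> None:
        # Stand-in for "open a fresh DB session and persist the counters".
        await asyncio.sleep(0.05)
        print(f"tracker {tracker_id}: +{tokens} tokens")


    async def handle_request() -> str:
        # Detach the write from the request: it keeps running after we return.
        task = asyncio.create_task(record_usage(1, 42))
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        return "response returned before usage is persisted"


    async def main() -> None:
        print(await handle_request())
        await asyncio.gather(*_background_tasks)  # demo only: let the write finish


    asyncio.run(main())

The add/discard pair guards a known asyncio pitfall: a task with no external
reference can be collected before it finishes. The patch itself calls
asyncio.create_task without retaining the returned task, which is the same
fire-and-forget shape.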
@@ -86,7 +98,9 @@ def __init__(self, user_id: int, db: AsyncSession, api_key_id: int | None = None
         self._keys_loaded = False
 
     @classmethod
-    async def async_get_instance(cls, user: User, db: AsyncSession, api_key_id: int | None = None) -> "ProviderService":
+    async def async_get_instance(
+        cls, user: User, db: AsyncSession, api_key_id: int | None = None
+    ) -> "ProviderService":
         """Get a cached instance of ProviderService for a user or create a new one (async version)"""
         cache_key = f"provider_service:{user.id}"
 
@@ -182,8 +196,12 @@ async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
         # Query ProviderKey directly by user_id
         from app.models.provider_key import ProviderKey
 
-        result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None))
-        provider_key_records = result.scalars().all()
+        result = await self.db.execute(
+            select(ProviderKey).filter(
+                ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None
+            )
+        )
+        provider_key_records = result.scalars().all()
 
         keys = {}
         for provider_key in provider_key_records:
@@ -204,7 +222,9 @@ async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
         logger.debug(
             f"Caching provider keys for user {self.user_id} (TTL: 3600s) (sync)"
         )
-        await async_provider_service_cache.set(cache_key, keys, ttl=3600)  # Cache for 1 hour
+        await async_provider_service_cache.set(
+            cache_key, keys, ttl=3600
+        )  # Cache for 1 hour
 
         return keys
 
@@ -233,7 +253,11 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]:
         # Query ProviderKey directly by user_id
         from app.models.provider_key import ProviderKey
 
-        result = await self.db.execute(select(ProviderKey).filter(ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None))
+        result = await self.db.execute(
+            select(ProviderKey).filter(
+                ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None
+            )
+        )
         provider_key_records = result.scalars().all()
 
         keys = {}
@@ -261,36 +285,79 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]:
 
         return keys
 
-    def _extract_provider_name_prefix(self, model: str) -> tuple[str | None, str]:
+    def _extract_provider_name_prefix(
+        self, model: str, allowed_provider_names: list[str] | set[str] | None = None
+    ) -> tuple[str | None, str]:
         """
         Extracts a provider name prefix from the model name if it exists.
         Returns (provider_prefix, model_name_without_prefix)
+
+        Args:
+            model: The model name to extract provider prefix from
+            allowed_provider_names: Optional list/set of allowed provider names for security scope.
+                If provided, only checks against these providers (which are already
+                the intersection of user's provider keys and API key scope).
+                If None, checks against user's provider keys.
         """
-        all_provider_names = {
-            p.lower() for p in ProviderAdapterFactory.get_all_adapters().keys()
-        }
-        model_parts = model.split("/")
-
-        # Find the longest matching provider name from the start of the model string
-        for i in range(len(model_parts), 0, -1):
-            potential_provider = "/".join(model_parts[:i]).lower()
-            if potential_provider in all_provider_names:
-                # If the provider name is the entire model string, and it has only one part,
-                # then we should treat it as a model name, not a provider prefix.
-                # e.g. model name is "openai", we shouldn't parse it as provider "openai" and empty model.
-                is_entire_model_string = i == len(model_parts)
-                if is_entire_model_string and len(model_parts) == 1:
-                    continue  # Skip, treat as model name
-
-                provider_name = potential_provider
-                model_name_without_prefix = "/".join(model_parts[i:])
-                return provider_name, model_name_without_prefix
+        # Determine which provider names to check against
+        if not self._keys_loaded:
+            return None, model
+
+        # Get user's available provider names
+        user_provider_names = {p.lower() for p in self.provider_keys.keys()}
+
+        if allowed_provider_names:
+            # allowed_provider_names is already the intersection of user's provider keys and API key scope
+            # (determined by the forge_api_key_provider_scope_association table)
+            allowed_providers_lower = {p.lower() for p in allowed_provider_names}
+
+            # Use the optimized approach with allowed providers
+            model_lower = model.lower()
+
+            # Sort candidates longest-first so a provider name that extends another
+            # (e.g. a hypothetical "openai/internal" vs "openai") is tried first
+            sorted_providers = sorted(allowed_providers_lower, key=len, reverse=True)
+
+            for provider in sorted_providers:
+                # Check if provider is at the start of the model string
+                if model_lower.startswith(provider + "/"):
+                    # Resolve the match back to an entry in allowed_provider_names
+                    original_provider = next(
+                        p for p in allowed_provider_names if p.lower() == provider
+                    )
+
+                    # Extract the model name without prefix
+                    prefix_length = len(provider) + 1  # +1 for the "/"
+                    model_name_without_prefix = model[prefix_length:]
+
+                    # Return the provider name in lowercase to match the provider keys
+                    return original_provider.lower(), model_name_without_prefix
+        else:
+            # Use the comprehensive approach checking user's provider keys
+            all_provider_names = user_provider_names
+
+            model_parts = model.split("/")
+
+            # Find the longest matching provider name from the start of the model string
+            for i in range(len(model_parts), 0, -1):
+                potential_provider = "/".join(model_parts[:i]).lower()
+                if potential_provider in all_provider_names:
+                    # If the provider name is the entire model string, and it has only one part,
+                    # then we should treat it as a model name, not a provider prefix.
+                    # e.g. model name is "openai", we shouldn't parse it as provider "openai" and empty model.
+                    is_entire_model_string = i == len(model_parts)
+                    if is_entire_model_string and len(model_parts) == 1:
+                        continue  # Skip, treat as model name
+
+                    provider_name = potential_provider
+                    model_name_without_prefix = "/".join(model_parts[i:])
+                    return provider_name, model_name_without_prefix
 
         return None, model
 
     def _get_provider_info_with_prefix(
         self, provider_name: str, model_name: str, original_model: str
-    ) -> tuple[str, str, str | None]:
+    ) -> tuple[str, str, str | None, int | None]:
         """Handles provider lookup when a prefix is found in the model name."""
         matching_provider = next(
             (key for key in self.provider_keys.keys() if key.lower() == provider_name),
@@ -315,7 +382,7 @@ def _get_provider_info_with_prefix(
 
     def _find_provider_for_unprefixed_model(
         self, model: str
-    ) -> tuple[str, str, str | None]:
+    ) -> tuple[str, str, str | None, int | None]:
         """Finds a provider for a model that doesn't have a provider prefix."""
         # Prioritize providers whose names are substrings of the model, e.g., "gemini" in "models/gemini-2.0-flash"
         # This helps resolve ambiguity when multiple providers might claim to support a model.
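[Editor's note: illustration, not part of the patch] The fast path added above is
a longest-first prefix match against the allowed provider names. A standalone
sketch of that rule; extract_prefix is a hypothetical name, and the nested
provider name is contrived to show why the ordering matters:

    def extract_prefix(model: str, providers: set[str]) -> tuple[str | None, str]:
        # Longest-first, so a name that extends another ("openai/internal"
        # vs "openai") is tried before its shorter sibling.
        model_lower = model.lower()
        for provider in sorted(providers, key=len, reverse=True):
            if model_lower.startswith(provider + "/"):
                return provider, model[len(provider) + 1 :]
        return None, model


    assert extract_prefix("openai/gpt-4", {"openai"}) == ("openai", "gpt-4")
    # A bare provider name stays a model name rather than an empty-model prefix.
    assert extract_prefix("openai", {"openai"}) == (None, "openai")
    # Longest-first ordering decides this one.
    assert extract_prefix("openai/internal/m1", {"openai", "openai/internal"}) == (
        "openai/internal",
        "m1",
    )

Note that the trailing "/" in the startswith check already keeps flat names such
as "openai" and "openai-custom" apart; the ordering only matters when one
candidate plus "/" is a prefix of another.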
@@ -342,16 +409,22 @@ def _find_provider_for_unprefixed_model(
         logger.error(f"No matching provider found for {model}")
         raise InvalidProviderException(model)
 
-    def _get_provider_info(self, model: str) -> tuple[str, str, str | None]:
+    def _get_provider_info(
+        self, model: str, allowed_provider_names: list[str] | set[str] | None = None
+    ) -> tuple[str, str, str | None, int | None]:
         """
         Determine the provider based on the model name.
+        If allowed_provider_names is provided, the lookup is restricted to those providers.
         """
         if not self._keys_loaded:
             error_message = "Provider keys must be loaded before calling _get_provider_info. Call _load_provider_keys_async() first."
             logger.error(error_message)
             raise RuntimeError(error_message)
 
-        provider_name, model_name_no_prefix = self._extract_provider_name_prefix(model)
+        # Use the unified provider name extraction method
+        provider_name, model_name_no_prefix = self._extract_provider_name_prefix(
+            model, allowed_provider_names
+        )
 
         if provider_name:
             return self._get_provider_info_with_prefix(
@@ -426,7 +499,9 @@ async def _list_models_helper(
                 return provider_models
             except Exception as e:
                 # Use parameterized logging to avoid issues if the error message contains braces
-                logger.error("Error fetching models for {}: {}", provider_name, str(e))
+                logger.error(
+                    "Error fetching models for {}: {}", provider_name, str(e)
+                )
                 return []
 
         provider_adapter_cls = ProviderAdapterFactory.get_adapter_cls(provider_name)
@@ -470,17 +545,20 @@ async def process_request(
             error_message = "Model is required"
             logger.error(error_message)
             raise BaseInvalidRequestException(
-                provider_name="unknown",
-                error=ValueError(error_message)
+                provider_name="unknown", error=ValueError(error_message)
             )
 
-        provider_name, actual_model, base_url, provider_key_id = self._get_provider_info(model)
+        provider_name, actual_model, base_url, provider_key_id = (
+            self._get_provider_info(model, allowed_provider_names)
+        )
 
         # Enforce scope restriction (case-insensitive).
         if allowed_provider_names is not None and (
             provider_name.lower() not in {p.lower() for p in allowed_provider_names}
         ):
-            error_message = f"API key is not permitted to use provider '{provider_name}'."
+            error_message = (
+                f"API key is not permitted to use provider '{provider_name}'."
+            )
             logger.error(error_message)
             raise InvalidForgeKeyException(error=ValueError(error_message))
 
@@ -493,7 +571,9 @@ async def process_request(
 
         # Get the provider's API key
         if provider_name not in self.provider_keys:
-            error_message = f"API key is not permitted to use provider '{provider_name}'."
+            error_message = (
+                f"API key is not permitted to use provider '{provider_name}'."
+            )
             logger.error(error_message)
             raise InvalidForgeKeyException(error=ValueError(error_message))
 
@@ -520,8 +600,12 @@ async def process_request(
         else:
             # TODO: this shouldn't happen, but we handle it gracefully as we don't want to break the flow
             # Dive deeper into this if it ever happens
-            logger.info(f"api_key_id: {self.api_key_id}, provider_key_id: {provider_key_id}")
-            logger.warning("No API key ID or provider key ID found, skipping usage tracking")
+            logger.info(
+                f"api_key_id: {self.api_key_id}, provider_key_id: {provider_key_id}"
+            )
+            logger.warning(
+                "No API key ID or provider key ID found, skipping usage tracking"
+            )
 
         if "completion" in endpoint:
             result = await adapter.process_completion(
@@ -532,7 +616,9 @@ async def process_request(
         elif "images/generations" in endpoint:
             # TODO: we only support openai for now
             if provider_name != "openai":
-                error_message = f"Unsupported endpoint: {endpoint} for provider {provider_name}"
+                error_message = (
+                    f"Unsupported endpoint: {endpoint} for provider {provider_name}"
+                )
                 logger.error(error_message)
                 raise NotImplementedError(error_message)
             result = await adapter.process_image_generation(
@@ -543,7 +629,9 @@ async def process_request(
         elif "images/edits" in endpoint:
             # TODO: we only support openai for now
             if provider_name != "openai":
-                error_message = f"Unsupported endpoint: {endpoint} for provider {provider_name}"
+                error_message = (
+                    f"Unsupported endpoint: {endpoint} for provider {provider_name}"
+                )
                 logger.error(error_message)
                 raise NotImplementedError(error_message)
             result = await adapter.process_image_edits(
@@ -580,10 +668,22 @@ async def process_request(
                 usage = result.get("usage", {})
                 input_tokens = usage.get("prompt_tokens", 0)
                 output_tokens = usage.get("completion_tokens", 0)
-                cached_tokens = usage.get("prompt_tokens_details", {}).get("cached_tokens", 0)
-                reasoning_tokens = usage.get("completion_tokens_details", {}).get("reasoning_tokens", 0)
+                cached_tokens = usage.get("prompt_tokens_details", {}).get(
+                    "cached_tokens", 0
+                )
+                reasoning_tokens = usage.get("completion_tokens_details", {}).get(
+                    "reasoning_tokens", 0
+                )
 
-                asyncio.create_task(update_usage_in_background(usage_tracker_id, input_tokens, output_tokens, cached_tokens, reasoning_tokens))
+                asyncio.create_task(
+                    update_usage_in_background(
+                        usage_tracker_id,
+                        input_tokens,
+                        output_tokens,
+                        cached_tokens,
+                        reasoning_tokens,
+                    )
+                )
             return result
         else:
             # For streaming responses, wrap the generator to count tokens
@@ -637,14 +737,24 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
                                     f"Found usage data in chunk: {data['usage']}"
                                 )
                                 usage = data.get("usage", {})
-                                input_tokens += usage.get(
-                                    "prompt_tokens", 0
-                                ) or 0
-                                output_tokens += usage.get(
-                                    "completion_tokens", 0
-                                ) or 0
-                                cached_tokens += usage.get("prompt_tokens_details", {}).get("cached_tokens", 0) or 0
-                                reasoning_tokens += usage.get("completion_tokens_details", {}).get("reasoning_tokens", 0) or 0
+                                input_tokens += (
+                                    usage.get("prompt_tokens", 0) or 0
+                                )
+                                output_tokens += (
+                                    usage.get("completion_tokens", 0) or 0
+                                )
+                                cached_tokens += (
+                                    usage.get("prompt_tokens_details", {}).get(
+                                        "cached_tokens", 0
+                                    )
+                                    or 0
+                                )
+                                reasoning_tokens += (
+                                    usage.get(
+                                        "completion_tokens_details", {}
+                                    ).get("reasoning_tokens", 0)
+                                    or 0
+                                )
 
                             # Extract content from the chunk based on OpenAI format
                             if "choices" in data:
@@ -657,7 +767,9 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
                                 # Only count tokens if we don't have final usage data
                                 if content:
                                     # Count tokens in content (approx)
-                                    approximate_output_tokens += len(content) // 4
+                                    approximate_output_tokens += (
+                                        len(content) // 4
+                                    )
                         except json.JSONDecodeError:
                             # If JSON parsing fails, just continue
                             pass
@@ -671,7 +783,9 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
 
             except Exception as e:
                 # Use parameterized logging to avoid issues if the error message contains braces
-                logger.error("Error in streaming response: {}", str(e), exc_info=True)
+                logger.error(
+                    "Error in streaming response: {}", str(e), exc_info=True
+                )
                 # Re-raise to propagate the error
                 raise
             finally:
@@ -681,12 +795,22 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
                     f"output_tokens={output_tokens}, cached_tokens={cached_tokens}, reasoning_tokens={reasoning_tokens}"
                 )
 
-                asyncio.create_task(update_usage_in_background(usage_tracker_id, input_tokens or approximate_input_tokens, output_tokens or approximate_output_tokens, cached_tokens, reasoning_tokens))
+                asyncio.create_task(
+                    update_usage_in_background(
+                        usage_tracker_id,
+                        input_tokens or approximate_input_tokens,
+                        output_tokens or approximate_output_tokens,
+                        cached_tokens,
+                        reasoning_tokens,
+                    )
+                )
 
         return token_counting_stream()
 
 
-async def create_default_tensorblock_provider_for_user(user_id: int, db: AsyncSession) -> None:
+async def create_default_tensorblock_provider_for_user(
+    user_id: int, db: AsyncSession
+) -> None:
     """
     Create a default TensorBlock provider key for a new user.
     This allows users to use Forge immediately without binding their own API keys.
diff --git a/tests/unit_tests/test_provider_service.py b/tests/unit_tests/test_provider_service.py
index 156b953..1c86b3f 100644
--- a/tests/unit_tests/test_provider_service.py
+++ b/tests/unit_tests/test_provider_service.py
@@ -50,7 +50,9 @@ async def asyncSetUp(self):
         self.provider_key_google.provider_name = "gemini"
         self.provider_key_google.encrypted_api_key = "encrypted_gemini_key"
         self.provider_key_google.base_url = None
-        self.provider_key_google.model_mapping = {"test-gemini": "models/gemini-2.0-flash"}
+        self.provider_key_google.model_mapping = {
+            "test-gemini": "models/gemini-2.0-flash"
+        }
 
         self.provider_key_xai = MagicMock(spec=ProviderKey)
         self.provider_key_xai.provider_name = "xai"
@@ -62,7 +64,9 @@ async def asyncSetUp(self):
         self.provider_key_fireworks.provider_name = "fireworks"
         self.provider_key_fireworks.encrypted_api_key = "encrypted_fireworks_key"
         self.provider_key_fireworks.base_url = None
-        self.provider_key_fireworks.model_mapping = {"test-fireworks": "accounts/fireworks/models/code-llama-7b"}
+        self.provider_key_fireworks.model_mapping = {
+            "test-fireworks": "accounts/fireworks/models/code-llama-7b"
+        }
 
         self.provider_key_openrouter = MagicMock(spec=ProviderKey)
         self.provider_key_openrouter.provider_name = "openrouter"
@@ -86,7 +90,15 @@ async def asyncSetUp(self):
         self.provider_key_bedrock.provider_name = "bedrock"
         self.provider_key_bedrock.encrypted_api_key = "encrypted_bedrock_key"
         self.provider_key_bedrock.base_url = None
-        self.provider_key_bedrock.model_mapping = {"test-bedrock": "claude-3-5-sonnet-20240620-v1:0"}
+        self.provider_key_bedrock.model_mapping = {
+            "test-bedrock": "claude-3-5-sonnet-20240620-v1:0"
+        }
+
+        self.provider_key_custom = MagicMock(spec=ProviderKey)
+        self.provider_key_custom.provider_name = "custom"
+        self.provider_key_custom.encrypted_api_key = "encrypted_custom_key"
+        self.provider_key_custom.base_url = None
+        self.provider_key_custom.model_mapping = {"xxxx": "custom-model"}
 
         self.user.provider_keys = [
             self.provider_key_openai,
@@ -117,16 +129,21 @@ async def asyncSetUp(self):
             "encrypted_fireworks_key": "decrypted_fireworks_key",
             "encrypted_openrouter_key": "decrypted_openrouter_key",
             "encrypted_together_key": "decrypted_together_key",
-            "encrypted_azure_key": json.dumps({
-                "api_key": "decrypted_azure_key",
-                "api_version": "2025-01-01-preview",
-            }),
-            "encrypted_bedrock_key": json.dumps({
-                "api_key": "decrypted_bedrock_key",
-                "region_name": "us-east-1",
-                "aws_access_key_id": "decrypted_aws_access_key_id",
-                "aws_secret_access_key": "decrypted_aws_secret_access_key",
-            }),
+            "encrypted_azure_key": json.dumps(
+                {
+                    "api_key": "decrypted_azure_key",
+                    "api_version": "2025-01-01-preview",
+                }
+            ),
+            "encrypted_bedrock_key": json.dumps(
+                {
+                    "api_key": "decrypted_bedrock_key",
+                    "region_name": "us-east-1",
+                    "aws_access_key_id": "decrypted_aws_access_key_id",
+                    "aws_secret_access_key": "decrypted_aws_secret_access_key",
+                }
+            ),
+            "encrypted_custom_key": "decrypted_custom_key",
         }
         mock_decrypt.side_effect = lambda key: decrypt_key_map[key]
@@ -147,9 +164,14 @@ async def asyncSetUp(self):
             self.provider_key_together,
             self.provider_key_azure,
             self.provider_key_bedrock,
+            self.provider_key_custom,
         ]
-        mock_result.scalars.return_value = mock_scalars  # scalars() returns sync object
-        self.db.execute = AsyncMock(return_value=mock_result)  # Only execute() is async
+        mock_result.scalars.return_value = (
+            mock_scalars  # scalars() returns sync object
+        )
+        self.db.execute = AsyncMock(
+            return_value=mock_result
+        )  # Only execute() is async
 
         self.service = ProviderService(self.user.id, self.db)
 
@@ -177,16 +199,26 @@ async def test_load_provider_keys(self):
         self.assertEqual(keys["fireworks"]["api_key"], "decrypted_fireworks_key")
         self.assertEqual(keys["openrouter"]["api_key"], "decrypted_openrouter_key")
         self.assertEqual(keys["together"]["api_key"], "decrypted_together_key")
-        self.assertEqual(keys["azure"]["api_key"], json.dumps({
-            "api_key": "decrypted_azure_key",
-            "api_version": "2025-01-01-preview",
-        }))
-        self.assertEqual(keys["bedrock"]["api_key"], json.dumps({
-            "api_key": "decrypted_bedrock_key",
-            "region_name": "us-east-1",
-            "aws_access_key_id": "decrypted_aws_access_key_id",
-            "aws_secret_access_key": "decrypted_aws_secret_access_key",
-        }))
+        self.assertEqual(
+            keys["azure"]["api_key"],
+            json.dumps(
+                {
+                    "api_key": "decrypted_azure_key",
+                    "api_version": "2025-01-01-preview",
+                }
+            ),
+        )
+        self.assertEqual(
+            keys["bedrock"]["api_key"],
+            json.dumps(
+                {
+                    "api_key": "decrypted_bedrock_key",
+                    "region_name": "us-east-1",
+                    "aws_access_key_id": "decrypted_aws_access_key_id",
+                    "aws_secret_access_key": "decrypted_aws_secret_access_key",
+                }
+            ),
+        )
         self.assertEqual(keys["openai"]["model_mapping"], {"custom-gpt": "gpt-4"})
         self.assertEqual(
             keys["gemini"]["model_mapping"], {"test-gemini": "models/gemini-2.0-flash"}
@@ -195,14 +227,18 @@ async def test_load_provider_keys(self):
     async def test_get_provider_info_explicit_mapping(self):
         """Test getting provider info with an explicitly mapped model"""
         # Since keys are already loaded in setUp, _get_provider_info should work directly
-        provider, model, base_url, provider_key_id = self.service._get_provider_info("custom-gpt")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
+            "custom-gpt"
+        )
         self.assertEqual(provider, "openai")
         self.assertEqual(model, "gpt-4")
         self.assertIsNone(base_url)
         self.assertEqual(provider_key_id, self.provider_key_openai.id)
 
-        provider, model, base_url, provider_key_id = self.service._get_provider_info("test-gemini")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
+            "test-gemini"
+        )
         self.assertEqual(provider, "gemini")
         self.assertEqual(model, "models/gemini-2.0-flash")
 
@@ -233,7 +269,9 @@ async def test_get_provider_info_prefix_matching(self):
         self.assertEqual(provider_key_id, self.provider_key_google.id)
 
         # Test XAI prefix
-        provider, model, base_url, provider_key_id = self.service._get_provider_info("xai/grok-2-1212")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
+            "xai/grok-2-1212"
+        )
         self.assertEqual(provider, "xai")
         self.assertEqual(provider_key_id, self.provider_key_xai.id)
 
@@ -258,14 +296,103 @@ async def test_get_provider_info_prefix_matching(self):
         self.assertEqual(provider, "together")
         self.assertEqual(provider_key_id, self.provider_key_together.id)
 
-        provider, model, base_url, provider_key_id = self.service._get_provider_info("azure/gpt-4o")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
+            "azure/gpt-4o"
+        )
         self.assertEqual(provider, "azure")
         self.assertEqual(provider_key_id, self.provider_key_azure.id)
 
-        provider, model, base_url, provider_key_id = self.service._get_provider_info("bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0")
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
+            "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
+        )
         self.assertEqual(provider, "bedrock")
         self.assertEqual(provider_key_id, self.provider_key_bedrock.id)
 
+    async def test_get_provider_info_with_allowed_providers(self):
+        """Test getting provider info with allowed_provider_names optimization"""
+        # Test with allowed providers that should match
+        allowed_providers = {"openai", "custom"}
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
+            "custom/xxxx", allowed_provider_names=allowed_providers
+        )
+        self.assertEqual(provider, "custom")
+        self.assertEqual(model, "custom-model")  # Model mapping applied
+        self.assertEqual(provider_key_id, self.provider_key_custom.id)
+
+        # Test with allowed providers that should not match (should fall back to original logic)
+        allowed_providers = {"openai", "anthropic"}
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
+            "xxxx", allowed_provider_names=allowed_providers
+        )
+        # Should fall back to model mapping lookup
+        self.assertEqual(provider, "custom")
+        self.assertEqual(model, "custom-model")  # Model mapping applied
+        self.assertEqual(provider_key_id, self.provider_key_custom.id)
+
+        # Test case-insensitive matching
+        allowed_providers = {"OPENAI", "CUSTOM"}
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
+            "custom/xxxx", allowed_provider_names=allowed_providers
+        )
+        self.assertEqual(provider, "custom")
+        self.assertEqual(model, "custom-model")  # Model mapping applied
+
+        # Test with no allowed providers (should use original logic)
+        provider, model, base_url, provider_key_id = self.service._get_provider_info(
+            "xxxx", allowed_provider_names=None
+        )
+        self.assertEqual(provider, "custom")
+        self.assertEqual(model, "custom-model")  # Model mapping applied
+
+    async def test_merged_extract_provider_name_prefix_function(self):
+        """Test that the merged _extract_provider_name_prefix function works with both approaches"""
+        # Test with allowed_provider_names (optimized approach)
+        allowed_providers = {"openai", "custom"}
+        provider, model = self.service._extract_provider_name_prefix(
+            "custom/xxxx", allowed_provider_names=allowed_providers
+        )
+        self.assertEqual(provider, "custom")
+        self.assertEqual(model, "xxxx")
+
+        # Test without allowed_provider_names (comprehensive approach)
+        provider, model = self.service._extract_provider_name_prefix("custom/xxxx")
+        self.assertEqual(provider, "custom")
+        self.assertEqual(model, "xxxx")
+
+        # Test with a provider that's not in allowed_provider_names
+        provider, model = self.service._extract_provider_name_prefix(
+            "anthropic/claude", allowed_provider_names={"openai", "custom"}
+        )
+        self.assertIsNone(provider)
+        self.assertEqual(model, "anthropic/claude")
+
+        # Test that both approaches return the same result for the same input
+        provider1, model1 = self.service._extract_provider_name_prefix(
+            "openai/gpt-4", allowed_provider_names={"openai", "anthropic"}
+        )
+        provider2, model2 = self.service._extract_provider_name_prefix("openai/gpt-4")
+        self.assertEqual(provider1, provider2)
+        self.assertEqual(model1, model2)
+
+    async def test_extract_provider_name_prefix_with_custom_provider(self):
+        """Test that _extract_provider_name_prefix works with custom providers not in adapter factory"""
+        # Test with a custom provider that's in the user's provider keys but not in adapter factory
+        provider, model = self.service._extract_provider_name_prefix("custom/xxxx")
+        self.assertEqual(provider, "custom")
+        self.assertEqual(model, "xxxx")
+
+        # Test with a provider that's in both adapter factory and user keys
+        provider, model = self.service._extract_provider_name_prefix("openai/gpt-4")
+        self.assertEqual(provider, "openai")
+        self.assertEqual(model, "gpt-4")
+
+        # Test a model with no provider prefix at all: nothing should be extracted,
+        # regardless of which providers exist in the adapter factory, since the
+        # extraction now only consults the user's provider keys
+        provider, model = self.service._extract_provider_name_prefix("gpt-4")
+        self.assertIsNone(provider)
+        self.assertEqual(model, "gpt-4")
+
     @patch("aiohttp.ClientSession.post")
     async def test_call_openai_api(self, mock_post):
         """Test calling the OpenAI API"""
@@ -302,7 +429,9 @@ async def mock_json():
     @patch("app.services.providers.adapter_factory.ProviderAdapterFactory.get_adapter")
     @patch("app.services.provider_service.decrypt_api_key")
     @patch("app.services.usage_stats_service.UsageStatsService.log_api_request")
-    async def test_process_request_routing(self, mock_log_usage, mock_decrypt, mock_get_adapter):
+    async def test_process_request_routing(
+        self, mock_log_usage, mock_decrypt, mock_get_adapter
+    ):
         """Test request routing based on model name"""
         # Create mocks for adapters
         mock_openai_adapter = MagicMock()
@@ -323,37 +452,59 @@ async def test_process_request_routing(self, mock_log_usage, mock_decrypt, mock_
             "encrypted_fireworks_key": "decrypted_fireworks_key",
             "encrypted_openrouter_key": "decrypted_openrouter_key",
             "encrypted_together_key": "decrypted_together_key",
-            "encrypted_azure_key": json.dumps({
-                "api_key": "decrypted_azure_key",
-                "api_version": "2025-01-01-preview",
-            }),
-            "encrypted_bedrock_key": json.dumps({
-                "api_key": "decrypted_bedrock_key",
-                "region_name": "us-east-1",
-                "aws_access_key_id": "decrypted_aws_access_key_id",
-                "aws_secret_access_key": "decrypted_aws_secret_access_key",
-            }),
+            "encrypted_azure_key": json.dumps(
+                {
+                    "api_key": "decrypted_azure_key",
+                    "api_version": "2025-01-01-preview",
+                }
+            ),
+            "encrypted_bedrock_key": json.dumps(
+                {
+                    "api_key": "decrypted_bedrock_key",
+                    "region_name": "us-east-1",
+                    "aws_access_key_id": "decrypted_aws_access_key_id",
+                    "aws_secret_access_key": "decrypted_aws_secret_access_key",
+                }
+            ),
         }
         mock_decrypt.side_effect = lambda key: decrypt_key_map[key]
 
         # Now we can mock the process_completion method
-        mock_openai_adapter.process_completion = AsyncMock(return_value={"id": "openai-response"})
+        mock_openai_adapter.process_completion = AsyncMock(
+            return_value={"id": "openai-response"}
+        )
 
-        mock_anthropic_adapter.process_completion = AsyncMock(return_value={"id": "anthropic-response"})
+        mock_anthropic_adapter.process_completion = AsyncMock(
+            return_value={"id": "anthropic-response"}
+        )
 
-        mock_gemini_adapter.process_completion = AsyncMock(return_value={"id": "gemini-response"})
+        mock_gemini_adapter.process_completion = AsyncMock(
+            return_value={"id": "gemini-response"}
+        )
 
-        mock_xai_adapter.process_completion = AsyncMock(return_value={"id": "xai-response"})
+        mock_xai_adapter.process_completion = AsyncMock(
+            return_value={"id": "xai-response"}
+        )
 
-        mock_fireworks_adapter.process_completion = AsyncMock(return_value={"id": "fireworks-response"})
+        mock_fireworks_adapter.process_completion = AsyncMock(
+            return_value={"id": "fireworks-response"}
+        )
 
-        mock_openrouter_adapter.process_completion = AsyncMock(return_value={"id": "openrouter-response"})
+        mock_openrouter_adapter.process_completion = AsyncMock(
+            return_value={"id": "openrouter-response"}
+        )
 
-        mock_together_adapter.process_completion = AsyncMock(return_value={"id": "together-response"})
+        mock_together_adapter.process_completion = AsyncMock(
+            return_value={"id": "together-response"}
+        )
 
-        mock_azure_adapter.process_completion = AsyncMock(return_value={"id": "azure-response"})
+        mock_azure_adapter.process_completion = AsyncMock(
+            return_value={"id": "azure-response"}
+        )
 
-        mock_bedrock_adapter.process_completion = AsyncMock(return_value={"id": "bedrock-response"})
+        mock_bedrock_adapter.process_completion = AsyncMock(
+            return_value={"id": "bedrock-response"}
+        )
 
         # Configure get_adapter to return the appropriate mock
         provider_mapping = {
@@ -367,9 +518,9 @@ async def test_process_request_routing(self, mock_log_usage, mock_decrypt, mock_
             "azure": mock_azure_adapter,
             "bedrock": mock_bedrock_adapter,
         }
-        mock_get_adapter.side_effect = lambda provider, base_url, config: provider_mapping[
-            provider
-        ]
+        mock_get_adapter.side_effect = (
+            lambda provider, base_url, config: provider_mapping[provider]
+        )
 
         # Test OpenAI routing
         result_openai = await self.service.process_request(
@@ -442,7 +593,8 @@ async def test_process_request_routing(self, mock_log_usage, mock_decrypt, mock_
 
         # Test Bedrock routing
         result_bedrock = await self.service.process_request(
-            "chat/completions", {"model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"}
+            "chat/completions",
+            {"model": "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"},
         )
         self.assertEqual(result_bedrock["id"], "bedrock-response")
         mock_bedrock_adapter.process_completion.assert_called_once()