Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 6 additions & 20 deletions app/services/provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,6 @@ def _get_adapters(self) -> dict[str, ProviderAdapter]:
ProviderService._adapters_cache = ProviderAdapterFactory.get_all_adapters()
return ProviderService._adapters_cache

def _ensure_model_mapping_dict(self, model_mapping: Any) -> dict[str, Any]:
"""Ensure model_mapping is a dictionary, handling cases where it might be a string."""
if isinstance(model_mapping, dict):
return model_mapping
elif isinstance(model_mapping, str):
try:
import json
return json.loads(model_mapping) if model_mapping else {}
except (json.JSONDecodeError, TypeError):
return {}
else:
return {}

async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
"""Load all provider keys for the user synchronously, with lazy loading and caching."""
if self._keys_loaded:
Expand Down Expand Up @@ -182,7 +169,7 @@ async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:

keys = {}
for provider_key in provider_key_records:
model_mapping = self._ensure_model_mapping_dict(provider_key.model_mapping or {})
model_mapping = provider_key.model_mapping or {}

keys[provider_key.provider_name] = {
"api_key": decrypt_api_key(provider_key.encrypted_api_key),
Expand Down Expand Up @@ -234,7 +221,7 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]:

keys = {}
for provider_key in provider_key_records:
model_mapping = self._ensure_model_mapping_dict(provider_key.model_mapping or {})
model_mapping = provider_key.model_mapping or {}

keys[provider_key.provider_name] = {
"api_key": decrypt_api_key(provider_key.encrypted_api_key),
Expand Down Expand Up @@ -298,7 +285,7 @@ def _get_provider_info_with_prefix(

provider_data = self.provider_keys[matching_provider]

model_mapping = self._ensure_model_mapping_dict(provider_data.get("model_mapping", {}))
model_mapping = provider_data.get("model_mapping", {})
mapped_model = model_mapping.get(model_name, model_name)
return (
matching_provider,
Expand All @@ -321,7 +308,7 @@ def _find_provider_for_unprefixed_model(

# Check custom model mappings
for provider_name, provider_data in sorted_providers:
model_mapping = self._ensure_model_mapping_dict(provider_data.get("model_mapping", {}))
model_mapping = provider_data.get("model_mapping", {})
if model in model_mapping:
mapped_model = model_mapping[model]
return (
Expand Down Expand Up @@ -382,8 +369,7 @@ async def list_models(

# Create a cache key unique to this provider config
base_url = provider_data.get("base_url", "default")
model_mapping = self._ensure_model_mapping_dict(provider_data.get("model_mapping", {}))
cache_key = f"{base_url}:{hash(frozenset(model_mapping.items()))}"
cache_key = f"{base_url}:{hash(frozenset(provider_data.get('model_mapping', {}).items()))}"

# Check if we have cached models for this provider
cached_models = await self.get_cached_models(provider_name, cache_key)
Expand All @@ -401,7 +387,7 @@ async def _list_models_helper(
) -> list[dict[str, Any]]:
try:
model_names = await adapter.list_models(api_key)
model_mapping = self._ensure_model_mapping_dict(provider_data.get("model_mapping", {}))
model_mapping = provider_data.get("model_mapping", {})
reverse_model_mapping = {v: k for k, v in model_mapping.items()}
provider_models = [
{
Expand Down
317 changes: 0 additions & 317 deletions tests/unit_tests/test_model_mapping_fix.py

This file was deleted.

Loading