Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions app/services/provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,19 @@ def _get_adapters(self) -> dict[str, ProviderAdapter]:
ProviderService._adapters_cache = ProviderAdapterFactory.get_all_adapters()
return ProviderService._adapters_cache

def _ensure_model_mapping_dict(self, model_mapping: Any) -> dict[str, Any]:
"""Ensure model_mapping is a dictionary, handling cases where it might be a string."""
if isinstance(model_mapping, dict):
return model_mapping
elif isinstance(model_mapping, str):
try:
import json
return json.loads(model_mapping) if model_mapping else {}
except (json.JSONDecodeError, TypeError):
return {}
else:
return {}

async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
"""Load all provider keys for the user synchronously, with lazy loading and caching."""
if self._keys_loaded:
Expand Down Expand Up @@ -169,7 +182,7 @@ async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:

keys = {}
for provider_key in provider_key_records:
model_mapping = provider_key.model_mapping or {}
model_mapping = self._ensure_model_mapping_dict(provider_key.model_mapping or {})

keys[provider_key.provider_name] = {
"api_key": decrypt_api_key(provider_key.encrypted_api_key),
Expand Down Expand Up @@ -221,7 +234,7 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]:

keys = {}
for provider_key in provider_key_records:
model_mapping = provider_key.model_mapping or {}
model_mapping = self._ensure_model_mapping_dict(provider_key.model_mapping or {})

keys[provider_key.provider_name] = {
"api_key": decrypt_api_key(provider_key.encrypted_api_key),
Expand Down Expand Up @@ -285,7 +298,7 @@ def _get_provider_info_with_prefix(

provider_data = self.provider_keys[matching_provider]

model_mapping = provider_data.get("model_mapping", {})
model_mapping = self._ensure_model_mapping_dict(provider_data.get("model_mapping", {}))
mapped_model = model_mapping.get(model_name, model_name)
return (
matching_provider,
Expand All @@ -308,7 +321,7 @@ def _find_provider_for_unprefixed_model(

# Check custom model mappings
for provider_name, provider_data in sorted_providers:
model_mapping = provider_data.get("model_mapping", {})
model_mapping = self._ensure_model_mapping_dict(provider_data.get("model_mapping", {}))
if model in model_mapping:
mapped_model = model_mapping[model]
return (
Expand Down Expand Up @@ -369,7 +382,8 @@ async def list_models(

# Create a cache key unique to this provider config
base_url = provider_data.get("base_url", "default")
cache_key = f"{base_url}:{hash(frozenset(provider_data.get('model_mapping', {}).items()))}"
model_mapping = self._ensure_model_mapping_dict(provider_data.get("model_mapping", {}))
cache_key = f"{base_url}:{hash(frozenset(model_mapping.items()))}"

# Check if we have cached models for this provider
cached_models = await self.get_cached_models(provider_name, cache_key)
Expand All @@ -387,7 +401,7 @@ async def _list_models_helper(
) -> list[dict[str, Any]]:
try:
model_names = await adapter.list_models(api_key)
model_mapping = provider_data.get("model_mapping", {})
model_mapping = self._ensure_model_mapping_dict(provider_data.get("model_mapping", {}))
reverse_model_mapping = {v: k for k, v in model_mapping.items()}
provider_models = [
{
Expand Down
317 changes: 317 additions & 0 deletions tests/unit_tests/test_model_mapping_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,317 @@
import pytest
from unittest.mock import AsyncMock, MagicMock
from sqlalchemy.ext.asyncio import AsyncSession

from app.services.provider_service import ProviderService
from app.models.provider_key import ProviderKey
from app.models.user import User


class TestModelMappingFix:
    """Test cases for the model_mapping string-to-dict conversion fix.

    Covers the ``_ensure_model_mapping_dict`` helper directly, plus the
    call sites (``list_models``, ``_get_provider_info_with_prefix``,
    ``_find_provider_for_unprefixed_model``) that previously crashed when
    ``model_mapping`` was persisted as a JSON string instead of a dict.
    """

    def test_ensure_model_mapping_dict_helper(self):
        """Test the _ensure_model_mapping_dict helper method with various inputs."""
        # Create a ProviderService instance (db can be None for this test)
        ps = ProviderService(1, None)

        # Test valid JSON string
        result = ps._ensure_model_mapping_dict('{"gpt-4": "gpt-4-turbo", "claude": "claude-3-opus"}')
        assert result == {"gpt-4": "gpt-4-turbo", "claude": "claude-3-opus"}

        # Test empty string
        result = ps._ensure_model_mapping_dict("")
        assert result == {}

        # Test None
        result = ps._ensure_model_mapping_dict(None)
        assert result == {}

        # Test already valid dict
        test_dict = {"test": "value"}
        result = ps._ensure_model_mapping_dict(test_dict)
        assert result == test_dict
        assert result is test_dict  # Should return the same object

        # Test invalid JSON string
        result = ps._ensure_model_mapping_dict('{invalid json}')
        assert result == {}

        # Test malformed JSON string (trailing comma is invalid JSON)
        result = ps._ensure_model_mapping_dict('{"key": "value",}')
        assert result == {}

        # Test non-string, non-dict input
        result = ps._ensure_model_mapping_dict(123)
        assert result == {}

        result = ps._ensure_model_mapping_dict([])
        assert result == {}

    @pytest.mark.asyncio
    async def test_list_models_with_string_model_mapping(self):
        """Test that list_models works correctly when model_mapping is stored as a string."""
        # Mock database session
        mock_db = AsyncMock(spec=AsyncSession)

        # Mock user
        # NOTE(review): mock_user is constructed but never wired into ps —
        # presumably vestigial; confirm and remove if unused.
        mock_user = MagicMock(spec=User)
        mock_user.id = 1

        # Create ProviderService instance
        ps = ProviderService(1, mock_db)

        # Mock the database query to return a provider key with string model_mapping
        mock_provider_key = MagicMock(spec=ProviderKey)
        mock_provider_key.provider_name = "openai"
        mock_provider_key.encrypted_api_key = "encrypted_key"
        mock_provider_key.base_url = "https://api.openai.com"
        # This simulates old data where model_mapping was stored as a string
        mock_provider_key.model_mapping = '{"gpt-4": "gpt-4-turbo"}'

        # Mock the database query result
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [mock_provider_key]
        mock_db.execute.return_value = mock_result

        # Mock the cache to return None (no cached data)
        ps._keys_loaded = False

        # Mock the provider adapter
        mock_adapter = MagicMock()
        mock_adapter.list_models = AsyncMock(return_value=["gpt-4", "gpt-3.5-turbo"])

        # Mock the adapter factory
        with pytest.MonkeyPatch().context() as m:
            m.setattr("app.services.provider_service.ProviderAdapterFactory.get_adapter_cls",
                      lambda x: MagicMock(deserialize_api_key_config=lambda x: ("api_key", {})))
            m.setattr("app.services.provider_service.ProviderAdapterFactory.get_adapter",
                      lambda x, y, z: mock_adapter)
            m.setattr("app.services.provider_service.decrypt_api_key", lambda x: "decrypted_key")
            m.setattr("app.services.provider_service.async_provider_service_cache.get",
                      AsyncMock(return_value=None))
            m.setattr("app.services.provider_service.async_provider_service_cache.set",
                      AsyncMock())

            # Call list_models - this should not raise an error
            result = await ps.list_models()

            # Verify the result
            assert isinstance(result, list)
            assert len(result) == 2  # Two models returned

            # Verify the models have the correct structure
            for model in result:
                assert "id" in model
                assert "display_name" in model
                assert "object" in model
                assert "owned_by" in model
                assert model["object"] == "model"
                assert model["owned_by"] == "openai"

    @pytest.mark.asyncio
    async def test_list_models_with_invalid_json_string(self):
        """Test that list_models handles invalid JSON strings gracefully."""
        # Mock database session
        mock_db = AsyncMock(spec=AsyncSession)

        # Mock user
        # NOTE(review): mock_user unused here as well — see note in the
        # string-mapping test above.
        mock_user = MagicMock(spec=User)
        mock_user.id = 1

        # Create ProviderService instance
        ps = ProviderService(1, mock_db)

        # Mock the database query to return a provider key with invalid JSON string
        mock_provider_key = MagicMock(spec=ProviderKey)
        mock_provider_key.provider_name = "openai"
        mock_provider_key.encrypted_api_key = "encrypted_key"
        mock_provider_key.base_url = "https://api.openai.com"
        # This simulates corrupted data
        mock_provider_key.model_mapping = '{invalid json string'

        # Mock the database query result
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [mock_provider_key]
        mock_db.execute.return_value = mock_result

        # Mock the cache to return None (no cached data)
        ps._keys_loaded = False

        # Mock the provider adapter
        mock_adapter = MagicMock()
        mock_adapter.list_models = AsyncMock(return_value=["gpt-4", "gpt-3.5-turbo"])

        # Mock the adapter factory
        with pytest.MonkeyPatch().context() as m:
            m.setattr("app.services.provider_service.ProviderAdapterFactory.get_adapter_cls",
                      lambda x: MagicMock(deserialize_api_key_config=lambda x: ("api_key", {})))
            m.setattr("app.services.provider_service.ProviderAdapterFactory.get_adapter",
                      lambda x, y, z: mock_adapter)
            m.setattr("app.services.provider_service.decrypt_api_key", lambda x: "decrypted_key")
            m.setattr("app.services.provider_service.async_provider_service_cache.get",
                      AsyncMock(return_value=None))
            m.setattr("app.services.provider_service.async_provider_service_cache.set",
                      AsyncMock())

            # Call list_models - this should not raise an error
            result = await ps.list_models()

            # Verify the result
            assert isinstance(result, list)
            assert len(result) == 2  # Two models returned

            # Since model_mapping was invalid, display_name should be the same as the model name
            # (ids are "provider/model", hence the split on "/")
            for model in result:
                assert model["display_name"] == model["id"].split("/")[1]

    @pytest.mark.asyncio
    async def test_list_models_with_none_model_mapping(self):
        """Test that list_models works correctly when model_mapping is None."""
        # Mock database session
        mock_db = AsyncMock(spec=AsyncSession)

        # Mock user
        # NOTE(review): mock_user unused here as well.
        mock_user = MagicMock(spec=User)
        mock_user.id = 1

        # Create ProviderService instance
        ps = ProviderService(1, mock_db)

        # Mock the database query to return a provider key with None model_mapping
        mock_provider_key = MagicMock(spec=ProviderKey)
        mock_provider_key.provider_name = "openai"
        mock_provider_key.encrypted_api_key = "encrypted_key"
        mock_provider_key.base_url = "https://api.openai.com"
        mock_provider_key.model_mapping = None

        # Mock the database query result
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [mock_provider_key]
        mock_db.execute.return_value = mock_result

        # Mock the cache to return None (no cached data)
        ps._keys_loaded = False

        # Mock the provider adapter
        mock_adapter = MagicMock()
        mock_adapter.list_models = AsyncMock(return_value=["gpt-4", "gpt-3.5-turbo"])

        # Mock the adapter factory
        with pytest.MonkeyPatch().context() as m:
            m.setattr("app.services.provider_service.ProviderAdapterFactory.get_adapter_cls",
                      lambda x: MagicMock(deserialize_api_key_config=lambda x: ("api_key", {})))
            m.setattr("app.services.provider_service.ProviderAdapterFactory.get_adapter",
                      lambda x, y, z: mock_adapter)
            m.setattr("app.services.provider_service.decrypt_api_key", lambda x: "decrypted_key")
            m.setattr("app.services.provider_service.async_provider_service_cache.get",
                      AsyncMock(return_value=None))
            m.setattr("app.services.provider_service.async_provider_service_cache.set",
                      AsyncMock())

            # Call list_models - this should not raise an error
            result = await ps.list_models()

            # Verify the result
            assert isinstance(result, list)
            assert len(result) == 2  # Two models returned

            # Since model_mapping was None, display_name should be the same as the model name
            for model in result:
                assert model["display_name"] == model["id"].split("/")[1]

    def test_get_provider_info_with_string_model_mapping(self):
        """Test that _get_provider_info_with_prefix works with string model_mapping."""
        ps = ProviderService(1, None)

        # Mock provider_keys with string model_mapping
        ps.provider_keys = {
            "openai": {
                "api_key": "test_key",
                "base_url": "https://api.openai.com",
                "model_mapping": '{"custom-gpt": "gpt-4"}'
            }
        }
        # Mark keys as loaded so the service skips the DB round-trip
        ps._keys_loaded = True

        # Test that it works correctly
        provider_name, mapped_model, base_url = ps._get_provider_info_with_prefix(
            "openai", "custom-gpt", "openai/custom-gpt"
        )

        assert provider_name == "openai"
        assert mapped_model == "gpt-4"  # Should be mapped correctly
        assert base_url == "https://api.openai.com"

    def test_find_provider_for_unprefixed_model_with_string_model_mapping(self):
        """Test that _find_provider_for_unprefixed_model works with string model_mapping."""
        ps = ProviderService(1, None)

        # Mock provider_keys with string model_mapping
        ps.provider_keys = {
            "openai": {
                "api_key": "test_key",
                "base_url": "https://api.openai.com",
                "model_mapping": '{"custom-gpt": "gpt-4"}'
            }
        }
        # Mark keys as loaded so the service skips the DB round-trip
        ps._keys_loaded = True

        # Test that it works correctly
        provider_name, mapped_model, base_url = ps._find_provider_for_unprefixed_model("custom-gpt")

        assert provider_name == "openai"
        assert mapped_model == "gpt-4"  # Should be mapped correctly
        assert base_url == "https://api.openai.com"

    def test_original_error_scenario_prevention(self):
        """Test that the original 'str' object has no attribute 'items' error is prevented."""
        ps = ProviderService(1, None)

        # Simulate the exact scenario that caused the original error
        # This would have caused the error before our fix
        provider_data = {
            "base_url": "https://api.openai.com",
            "model_mapping": '{"gpt-4": "gpt-4-turbo"}'  # String instead of dict
        }

        # This line would have failed before our fix:
        # cache_key = f"{base_url}:{hash(frozenset(provider_data.get('model_mapping', {}).items()))}"
        # Because provider_data.get('model_mapping', {}) would return a string, and strings don't have .items()

        # Now with our fix, this should work:
        base_url = provider_data.get("base_url", "default")
        model_mapping = ps._ensure_model_mapping_dict(provider_data.get("model_mapping", {}))
        cache_key = f"{base_url}:{hash(frozenset(model_mapping.items()))}"

        # Verify that no error was raised and we got a valid cache key
        assert isinstance(cache_key, str)
        assert "https://api.openai.com" in cache_key
        assert len(cache_key) > 0

    def test_cache_key_generation_with_various_model_mappings(self):
        """Test that cache key generation works with various model_mapping types."""
        ps = ProviderService(1, None)

        test_cases = [
            # (model_mapping, expected_type)
            ('{"gpt-4": "gpt-4-turbo"}', dict),
            ('', dict),
            (None, dict),
            ('{invalid json}', dict),
            ({"valid": "dict"}, dict),
        ]

        for model_mapping, expected_type in test_cases:
            result = ps._ensure_model_mapping_dict(model_mapping)
            assert isinstance(result, expected_type)

            # Test that we can call .items() on the result
            items = result.items()
            assert hasattr(items, '__iter__')  # Should be iterable

            # Test cache key generation
            base_url = "https://api.openai.com"
            cache_key = f"{base_url}:{hash(frozenset(result.items()))}"
            assert isinstance(cache_key, str)
            assert base_url in cache_key
Loading