diff --git a/app/services/providers/openai_adapter.py b/app/services/providers/openai_adapter.py
index 933778b..a608dc2 100644
--- a/app/services/providers/openai_adapter.py
+++ b/app/services/providers/openai_adapter.py
@@ -40,9 +40,16 @@ def get_model_id(self, payload: dict[str, Any]) -> str:
             logger.error(f"Model ID not found in payload for {self.provider_name}")
             raise BaseInvalidRequestException(
                 provider_name=self.provider_name,
-                error=ValueError("Model ID not found in payload")
+                error=ValueError("Model ID not found in payload"),
             )
 
+    def _ensure_list(
+        self, value: str | list[str] | list[int] | list[list[int]]
+    ) -> list[str] | list[int] | list[list[int]]:
+        if not isinstance(value, list):
+            return [value]
+        return value
+
     async def list_models(
         self,
         api_key: str,
@@ -70,11 +77,13 @@ async def list_models(
         ):
             if response.status != HTTPStatus.OK:
                 error_text = await response.text()
-                logger.error(f"List Models API error for {self.provider_name}: {error_text}")
+                logger.error(
+                    f"List Models API error for {self.provider_name}: {error_text}"
+                )
                 raise ProviderAPIException(
                     provider_name=self.provider_name,
                     error_code=response.status,
-                    error_message=error_text
+                    error_message=error_text,
                 )
 
             resp = await response.json()
@@ -82,7 +91,8 @@ async def list_models(
         models_list = resp["data"] if isinstance(resp, dict) else resp
 
         self.OPENAI_MODEL_MAPPING = {
-            d.get("name", self.get_model_id(d)): self.get_model_id(d) for d in models_list
+            d.get("name", self.get_model_id(d)): self.get_model_id(d)
+            for d in models_list
         }
 
         models = [self.get_model_id(d) for d in models_list]
@@ -122,11 +132,13 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
             ):
                 if response.status != HTTPStatus.OK:
                     error_text = await response.text()
-                    logger.error(f"Completion Streaming API error for {self.provider_name}: {error_text}")
+                    logger.error(
+                        f"Completion Streaming API error for {self.provider_name}: {error_text}"
+                    )
                     raise ProviderAPIException(
                         provider_name=self.provider_name,
                         error_code=response.status,
-                        error_message=error_text
+                        error_message=error_text,
                     )
 
                 # Stream the response back
@@ -148,11 +160,13 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
         ):
             if response.status != HTTPStatus.OK:
                 error_text = await response.text()
-                logger.error(f"Completion API error for {self.provider_name}: {error_text}")
+                logger.error(
+                    f"Completion API error for {self.provider_name}: {error_text}"
+                )
                 raise ProviderAPIException(
                     provider_name=self.provider_name,
                     error_code=response.status,
-                    error_message=error_text
+                    error_message=error_text,
                 )
 
             return await response.json()
@@ -177,11 +191,13 @@ async def process_image_generation(
         ):
             if response.status != HTTPStatus.OK:
                 error_text = await response.text()
-                logger.error(f"Image Generation API error for {self.provider_name}: {error_text}")
+                logger.error(
+                    f"Image Generation API error for {self.provider_name}: {error_text}"
+                )
                 raise ProviderAPIException(
                     provider_name=self.provider_name,
                     error_code=response.status,
-                    error_message=error_text
+                    error_message=error_text,
                 )
 
             return await response.json()
@@ -210,11 +226,11 @@ async def process_image_edits(
                 raise ProviderAPIException(
                     provider_name=self.provider_name,
                     error_code=response.status,
-                    error_message=error_text
+                    error_message=error_text,
                 )
 
             return await response.json()
-    
+
     async def process_embeddings(
         self,
         endpoint: str,
@@ -230,6 +246,9 @@ async def process_embeddings(
             "Content-Type": "application/json",
         }
 
+        # process single and batch jobs
+        payload["input"] = self._ensure_list(payload["input"])
+
         # inpput_type is for cohere embeddings only
         if "input_type" in payload:
             del payload["input_type"]
@@ -239,15 +258,19 @@ async def process_embeddings(
 
         async with (
             aiohttp.ClientSession() as session,
-            session.post(url, headers=headers, json=payload, params=query_params) as response,
+            session.post(
+                url, headers=headers, json=payload, params=query_params
+            ) as response,
         ):
             if response.status != HTTPStatus.OK:
                 error_text = await response.text()
-                logger.error(f"Embeddings API error for {self.provider_name}: {error_text}")
+                logger.error(
+                    f"Embeddings API error for {self.provider_name}: {error_text}"
+                )
                 raise ProviderAPIException(
                     provider_name=self.provider_name,
                     error_code=response.status,
-                    error_message=error_text
+                    error_message=error_text,
                 )
 
             return await response.json()
diff --git a/app/services/providers/tensorblock_adapter.py b/app/services/providers/tensorblock_adapter.py
index 9fef10c..0c350ad 100644
--- a/app/services/providers/tensorblock_adapter.py
+++ b/app/services/providers/tensorblock_adapter.py
@@ -3,10 +3,8 @@
 from .azure_adapter import AzureAdapter
 
 TENSORBLOCK_MODELS = [
-    "gpt-4.1",
     "gpt-4.1-mini",
     "gpt-4.1-nano",
-    "gpt-4o",
     "gpt-4o-mini",
     "o3-mini",
     "text-embedding-3-large",
diff --git a/tests/unit_tests/docs/anthropic/chat_completion_response_1.json b/tests/unit_tests/assets/anthropic/chat_completion_response_1.json
similarity index 100%
rename from tests/unit_tests/docs/anthropic/chat_completion_response_1.json
rename to tests/unit_tests/assets/anthropic/chat_completion_response_1.json
diff --git a/tests/unit_tests/docs/anthropic/chat_completion_streaming_response_1.json b/tests/unit_tests/assets/anthropic/chat_completion_streaming_response_1.json
similarity index 100%
rename from tests/unit_tests/docs/anthropic/chat_completion_streaming_response_1.json
rename to tests/unit_tests/assets/anthropic/chat_completion_streaming_response_1.json
diff --git a/tests/unit_tests/docs/anthropic/list_models.json b/tests/unit_tests/assets/anthropic/list_models.json
similarity index 100%
rename from tests/unit_tests/docs/anthropic/list_models.json
rename to tests/unit_tests/assets/anthropic/list_models.json
diff --git a/tests/unit_tests/docs/google/chat_completion_response_1.json b/tests/unit_tests/assets/google/chat_completion_response_1.json
similarity index 100%
rename from tests/unit_tests/docs/google/chat_completion_response_1.json
rename to tests/unit_tests/assets/google/chat_completion_response_1.json
diff --git a/tests/unit_tests/docs/google/chat_completion_streaming_response_1.json b/tests/unit_tests/assets/google/chat_completion_streaming_response_1.json
similarity index 100%
rename from tests/unit_tests/docs/google/chat_completion_streaming_response_1.json
rename to tests/unit_tests/assets/google/chat_completion_streaming_response_1.json
diff --git a/tests/unit_tests/docs/google/list_models.json b/tests/unit_tests/assets/google/list_models.json
similarity index 100%
rename from tests/unit_tests/docs/google/list_models.json
rename to tests/unit_tests/assets/google/list_models.json
diff --git a/tests/unit_tests/docs/openai/chat_completion_response_1.json b/tests/unit_tests/assets/openai/chat_completion_response_1.json
similarity index 100%
rename from tests/unit_tests/docs/openai/chat_completion_response_1.json
rename to tests/unit_tests/assets/openai/chat_completion_response_1.json
diff --git a/tests/unit_tests/docs/openai/chat_completion_streaming_response_1.json b/tests/unit_tests/assets/openai/chat_completion_streaming_response_1.json
similarity index 100%
rename from tests/unit_tests/docs/openai/chat_completion_streaming_response_1.json
rename to tests/unit_tests/assets/openai/chat_completion_streaming_response_1.json
diff --git a/tests/unit_tests/assets/openai/embeddings_response.json b/tests/unit_tests/assets/openai/embeddings_response.json
new file mode 100644
index 0000000..e296aba
--- /dev/null
+++ b/tests/unit_tests/assets/openai/embeddings_response.json
@@ -0,0 +1,20 @@
+{
+  "object": "list",
+  "data": [
+    {
+      "object": "embedding",
+      "index": 0,
+      "embedding": [
+        -0.006929283495992422,
+        -0.005336422007530928,
+        -4.547132266452536e-05,
+        -0.024047503247857094
+      ]
+    }
+  ],
+  "model": "text-embedding-ada-002",
+  "usage": {
+    "prompt_tokens": 8,
+    "total_tokens": 8
+  }
+}
\ No newline at end of file
diff --git a/tests/unit_tests/docs/openai/list_models.json b/tests/unit_tests/assets/openai/list_models.json
similarity index 100%
rename from tests/unit_tests/docs/openai/list_models.json
rename to tests/unit_tests/assets/openai/list_models.json
diff --git a/tests/unit_tests/test_openai_provider.py b/tests/unit_tests/test_openai_provider.py
index 3bf3416..ebc2be4 100644
--- a/tests/unit_tests/test_openai_provider.py
+++ b/tests/unit_tests/test_openai_provider.py
@@ -30,6 +30,11 @@
 ) as f:
     MOCK_CHAT_COMPLETION_STREAMING_RESPONSE_DATA = json.load(f)
 
+with open(
+    os.path.join(CURRENT_DIR, "assets", "openai", "embeddings_response.json"), "r"
+) as f:
+    MOCK_EMBEDDINGS_RESPONSE_DATA = json.load(f)
+
 
 class TestOpenAIProvider(TestCase):
     def setUp(self):
@@ -108,3 +113,20 @@ async def test_chat_completion_streaming(self):
             expected_model="gpt-4o-mini-2024-07-18",
             expected_message=OPENAAI_STANDARD_CHAT_COMPLETION_RESPONSE,
         )
+
+    async def test_process_embeddings(self):
+        payload = {
+            "model": "text-embedding-ada-002",
+            "input": ["hello", "world"],
+        }
+        with patch("aiohttp.ClientSession", ClientSessionMock()) as mock_session:
+            mock_session.responses = [(MOCK_EMBEDDINGS_RESPONSE_DATA, 200)]
+
+            # Call the method
+            result = await self.adapter.process_embeddings(
+                api_key=self.api_key, payload=payload, endpoint="embeddings"
+            )
+            self.assertEqual(result, MOCK_EMBEDDINGS_RESPONSE_DATA)
+
+            # verify that the payload sent to openai has a list as input
+            self.assertIsInstance(mock_session.posted_json[0]["input"], list)
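
For context on the process_embeddings change: payload["input"] is now normalized with _ensure_list before the request is posted, so callers may pass either a bare string or a batch. Below is a minimal sketch of that normalization, rewritten as a free function purely for illustration (the adapter's constructor and module imports are not shown in this diff, so the name ensure_list here is hypothetical):

def ensure_list(value):
    """Wrap a bare value in a list; pass existing lists through unchanged."""
    if not isinstance(value, list):
        return [value]
    return value

# A single string becomes a one-element batch...
assert ensure_list("hello") == ["hello"]
# ...while string batches and token-ID batches pass through untouched.
assert ensure_list(["hello", "world"]) == ["hello", "world"]
assert ensure_list([[1, 2], [3, 4]]) == [[1, 2], [3, 4]]

This is the invariant test_process_embeddings asserts: the JSON body posted to the provider always carries a list under "input".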