diff --git a/app/services/providers/azure_adapter.py b/app/services/providers/azure_adapter.py index 04936b3..e19bed0 100644 --- a/app/services/providers/azure_adapter.py +++ b/app/services/providers/azure_adapter.py @@ -90,8 +90,6 @@ def process_streaming_chunk(chunk: bytes): return chunk return chunk - - async def process_completion( self, endpoint: str, @@ -109,6 +107,23 @@ async def process_completion( } return await super().process_completion(endpoint, payload, api_key, base_url, query_params) + async def process_embeddings( + self, + endpoint: str, + payload: dict[str, Any], + api_key: str, + ) -> Any: + """Process an embeddings request using Azure API""" + # Azure API requires the model to be in the path + model_id = payload["model"] + del payload["model"] + base_url = f"{self._base_url}/openai/deployments/{model_id}" + + query_params = { + "api-version": self.api_version, + } + return await super().process_embeddings(endpoint, payload, api_key, base_url, query_params) + async def list_models(self, api_key: str) -> list[str]: base_url = f"{self._base_url}/openai" query_params = {