diff --git a/app/services/providers/openai_adapter.py b/app/services/providers/openai_adapter.py
index a608dc2..b0e1a37 100644
--- a/app/services/providers/openai_adapter.py
+++ b/app/services/providers/openai_adapter.py
@@ -4,7 +4,10 @@
 import aiohttp
 
 from app.core.logger import get_logger
-from app.exceptions.exceptions import ProviderAPIException, BaseInvalidRequestException
+from app.exceptions import (
+    ProviderAPIException,
+    BaseInvalidRequestException,
+)
 
 from .base import ProviderAdapter
 
@@ -12,6 +15,9 @@
 logger = get_logger(name="openai_adapter")
 
 
+MAX_BATCH_SIZE = 2048
+
+
 class OpenAIAdapter(ProviderAdapter):
     """Adapter for OpenAI API"""
 
@@ -256,21 +262,36 @@ async def process_embeddings(
         url = f"{base_url or self._base_url}/{endpoint}"
         query_params = query_params or {}
 
-        async with (
-            aiohttp.ClientSession() as session,
-            session.post(
-                url, headers=headers, json=payload, params=query_params
-            ) as response,
-        ):
-            if response.status != HTTPStatus.OK:
-                error_text = await response.text()
-                logger.error(
-                    f"Embeddings API error for {self.provider_name}: {error_text}"
-                )
-                raise ProviderAPIException(
-                    provider_name=self.provider_name,
-                    error_code=response.status,
-                    error_message=error_text,
-                )
+        all_embeddings = []
+        for i in range(0, len(payload["input"]), MAX_BATCH_SIZE):
+            batch_payload = payload.copy()
+            batch_payload["input"] = payload["input"][i : i + MAX_BATCH_SIZE]
 
-            return await response.json()
+            async with (
+                aiohttp.ClientSession() as session,
+                session.post(
+                    url, headers=headers, json=batch_payload, params=query_params
+                ) as response,
+            ):
+                if response.status != HTTPStatus.OK:
+                    error_text = await response.text()
+                    logger.error(
+                        f"Embeddings API error for {self.provider_name}: {error_text}"
+                    )
+                    raise ProviderAPIException(
+                        provider_name=self.provider_name,
+                        error_code=response.status,
+                        error_message=error_text,
+                    )
+
+                response_json = await response.json()
+                all_embeddings.extend(response_json["data"])
+
+        # Combine the results into a single response
+        final_response = {
+            "object": "list",
+            "data": all_embeddings,
+            "model": response_json["model"],
+            "usage": response_json["usage"],
+        }
+        return final_response