diff --git a/app/services/providers/openai_adapter.py b/app/services/providers/openai_adapter.py
index b3db07c..20f5036 100644
--- a/app/services/providers/openai_adapter.py
+++ b/app/services/providers/openai_adapter.py
@@ -16,6 +16,7 @@
 MAX_BATCH_SIZE = 2048
+MAX_TOKENS_PER_BATCH = 8192  # per-batch token budget (OpenAI embedding models accept ~8k tokens per input)
 
 
 class OpenAIAdapter(ProviderAdapter):
@@ -56,6 +57,49 @@ def _ensure_list(
             return [value]
         return value
 
+    def _estimate_tokens(self, text: str) -> int:
+        """Estimate the token count of a text string (rough approximation)."""
+        # Rough heuristic: ~1 token per 4 characters of English text.
+        # Actual usage may be higher for code or non-English text.
+        estimated = len(text) // 4 + 1
+
+        # Return the raw estimate; oversized inputs are handled in _create_token_aware_batches
+        return estimated
+
+    def _create_token_aware_batches(self, inputs: list[str]) -> list[list[str]]:
+        """Create batches based on estimated token count rather than input count alone."""
+        batches = []
+        current_batch = []
+        current_token_count = 0
+
+        for input_text in inputs:
+            estimated_tokens = self._estimate_tokens(input_text)
+
+            # A single input that exceeds the limit is sent as its own batch
+            if estimated_tokens > MAX_TOKENS_PER_BATCH:
+                logger.warning(f"Single input exceeds token limit ({estimated_tokens} tokens), processing alone")
+                if current_batch:
+                    batches.append(current_batch)
+                batches.append([input_text])
+                current_batch = []
+                current_token_count = 0
+                continue
+
+            # If adding this input would exceed the token limit, start a new batch
+            if current_token_count + estimated_tokens > MAX_TOKENS_PER_BATCH and current_batch:
+                batches.append(current_batch)
+                current_batch = [input_text]
+                current_token_count = estimated_tokens
+            else:
+                current_batch.append(input_text)
+                current_token_count += estimated_tokens
+
+        # Add the last batch if it has content
+        if current_batch:
+            batches.append(current_batch)
+
+        return batches
+
     async def list_models(
         self,
         api_key: str,
@@ -263,9 +307,17 @@ async def process_embeddings(
         query_params = query_params or {}
 
         all_embeddings = []
-        for i in range(0, len(payload["input"]), MAX_BATCH_SIZE):
+        total_usage = {"prompt_tokens": 0, "total_tokens": 0}
+
+        # Create token-aware batches
+        batches = self._create_token_aware_batches(payload["input"])
+
+        logger.info(f"Created {len(batches)} batches for {len(payload['input'])} inputs")
+
+        for i, batch_inputs in enumerate(batches):
+            logger.debug(f"Processing batch {i+1}/{len(batches)} with {len(batch_inputs)} inputs")
             batch_payload = payload.copy()
-            batch_payload["input"] = payload["input"][i : i + MAX_BATCH_SIZE]
+            batch_payload["input"] = batch_inputs
 
             async with (
                 aiohttp.ClientSession() as session,
@@ -286,12 +338,17 @@
                 response_json = await response.json()
                 all_embeddings.extend(response_json["data"])
+
+                # Accumulate usage statistics across batches
+                if "usage" in response_json:
+                    total_usage["prompt_tokens"] += response_json["usage"].get("prompt_tokens", 0)
+                    total_usage["total_tokens"] += response_json["usage"].get("total_tokens", 0)
 
         # Combine the results into a single response
         final_response = {
             "object": "list",
             "data": all_embeddings,
             "model": response_json["model"],
-            "usage": response_json["usage"],
+            "usage": total_usage,
         }
 
         return final_response
diff --git a/tests/unit_tests/test_anthropic_provider.py b/tests/unit_tests/test_anthropic_provider.py
index 298ae03..9b33c7b 100644
--- a/tests/unit_tests/test_anthropic_provider.py
+++ b/tests/unit_tests/test_anthropic_provider.py
@@ -14,18 +14,18 @@
 CURRENT_DIR = os.path.dirname(__file__)
 
-with open(os.path.join(CURRENT_DIR, "docs", "anthropic", "list_models.json"), "r") as f:
+with open(os.path.join(CURRENT_DIR, "assets", "anthropic", "list_models.json"), "r") as f:
     MOCK_LIST_MODELS_RESPONSE_DATA = json.load(f)
 
 with open(
-    os.path.join(CURRENT_DIR, "docs", "anthropic", "chat_completion_response_1.json"),
+    os.path.join(CURRENT_DIR, "assets", "anthropic", "chat_completion_response_1.json"),
     "r",
 ) as f:
     MOCK_CHAT_COMPLETION_RESPONSE_DATA = json.load(f)
 
 with open(
     os.path.join(
-        CURRENT_DIR, "docs", "anthropic", "chat_completion_streaming_response_1.json"
+        CURRENT_DIR, "assets", "anthropic", "chat_completion_streaming_response_1.json"
     ),
     "r",
 ) as f:
diff --git a/tests/unit_tests/test_google_provider.py b/tests/unit_tests/test_google_provider.py
index 9396ad1..be152e4 100644
--- a/tests/unit_tests/test_google_provider.py
+++ b/tests/unit_tests/test_google_provider.py
@@ -14,17 +14,17 @@
 CURRENT_DIR = os.path.dirname(__file__)
 
-with open(os.path.join(CURRENT_DIR, "docs", "google", "list_models.json"), "r") as f:
+with open(os.path.join(CURRENT_DIR, "assets", "google", "list_models.json"), "r") as f:
     MOCK_LIST_MODELS_RESPONSE_DATA = json.load(f)
 
 with open(
-    os.path.join(CURRENT_DIR, "docs", "google", "chat_completion_response_1.json"), "r"
+    os.path.join(CURRENT_DIR, "assets", "google", "chat_completion_response_1.json"), "r"
 ) as f:
     MOCK_CHAT_COMPLETION_RESPONSE_DATA = json.load(f)
 
 with open(
     os.path.join(
-        CURRENT_DIR, "docs", "google", "chat_completion_streaming_response_1.json"
+        CURRENT_DIR, "assets", "google", "chat_completion_streaming_response_1.json"
     ),
     "r",
 ) as f:
diff --git a/tests/unit_tests/test_openai_provider.py b/tests/unit_tests/test_openai_provider.py
index ebc2be4..7d9edd3 100644
--- a/tests/unit_tests/test_openai_provider.py
+++ b/tests/unit_tests/test_openai_provider.py
@@ -14,24 +14,24 @@
 CURRENT_DIR = os.path.dirname(__file__)
 
-with open(os.path.join(CURRENT_DIR, "docs", "openai", "list_models.json"), "r") as f:
+with open(os.path.join(CURRENT_DIR, "assets", "openai", "list_models.json"), "r") as f:
    MOCK_LIST_MODELS_RESPONSE_DATA = json.load(f)
 
 with open(
-    os.path.join(CURRENT_DIR, "docs", "openai", "chat_completion_response_1.json"), "r"
+    os.path.join(CURRENT_DIR, "assets", "openai", "chat_completion_response_1.json"), "r"
 ) as f:
     MOCK_CHAT_COMPLETION_RESPONSE_DATA = json.load(f)
 
 with open(
     os.path.join(
-        CURRENT_DIR, "docs", "openai", "chat_completion_streaming_response_1.json"
+        CURRENT_DIR, "assets", "openai", "chat_completion_streaming_response_1.json"
     ),
     "r",
 ) as f:
     MOCK_CHAT_COMPLETION_STREAMING_RESPONSE_DATA = json.load(f)
 
 with open(
-    os.path.join(CURRENT_DIR, "docs", "openai", "embeddings_response.json"), "r"
+    os.path.join(CURRENT_DIR, "assets", "openai", "embeddings_response.json"), "r"
 ) as f:
     MOCK_EMBEDDINGS_RESPONSE_DATA = json.load(f)
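
For reference, a rough sketch of how the new token-aware batching behaves under the ~4-characters-per-token heuristic. This is illustrative only and not part of the patch; it assumes OpenAIAdapter can be imported from app.services.providers.openai_adapter and constructed with no arguments, which the patch does not confirm.

# Illustrative sketch, not included in the diff above.
from app.services.providers.openai_adapter import OpenAIAdapter  # assumed import path

adapter = OpenAIAdapter()  # assumed no-arg constructor

# Ten short inputs (~3 estimated tokens each) plus one oversized input
# (~10k estimated tokens, above MAX_TOKENS_PER_BATCH).
inputs = ["short text"] * 10 + ["x" * 40_000]
batches = adapter._create_token_aware_batches(inputs)

# The short inputs share one batch; the oversized input becomes its own
# batch and triggers the warning log.
assert len(batches) == 2
assert batches[1] == ["x" * 40_000]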