Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions app/services/providers/openai_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


MAX_BATCH_SIZE = 2048
MAX_TOKENS_PER_BATCH = 8192 # OpenAI's limit for embeddings


class OpenAIAdapter(ProviderAdapter):
Expand Down Expand Up @@ -56,6 +57,49 @@ def _ensure_list(
return [value]
return value

def _estimate_tokens(self, text: str) -> int:
"""Estimate token count for a text string (rough approximation)"""
# Rough approximation: 1 token ≈ 4 characters for English text
# This is a conservative estimate
estimated = len(text) // 4 + 1

# Cap at a reasonable maximum to prevent extremely large batches
return min(estimated, MAX_TOKENS_PER_BATCH // 2)

def _create_token_aware_batches(self, inputs: list[str]) -> list[list[str]]:
    """Partition *inputs* into batches bounded by estimated token count.

    Inputs are kept in order. A new batch is started whenever adding the
    next input would push the running token estimate past
    MAX_TOKENS_PER_BATCH. An input whose own estimate already exceeds
    the limit is emitted as a single-element batch (with a warning) so
    it cannot poison a shared request.

    Args:
        inputs: Ordered list of strings to embed.

    Returns:
        A list of batches (each a non-empty list of strings) covering
        every input exactly once, in the original order.
    """
    grouped: list[list[str]] = []
    pending: list[str] = []
    pending_tokens = 0

    for text in inputs:
        tokens = self._estimate_tokens(text)

        if tokens > MAX_TOKENS_PER_BATCH:
            # Oversized input: flush whatever is pending, then isolate it.
            logger.warning(f"Single input exceeds token limit ({tokens} tokens), processing alone")
            if pending:
                grouped.append(pending)
            grouped.append([text])
            pending, pending_tokens = [], 0
        elif pending and pending_tokens + tokens > MAX_TOKENS_PER_BATCH:
            # Current batch is full: seal it and start a fresh one.
            grouped.append(pending)
            pending, pending_tokens = [text], tokens
        else:
            pending.append(text)
            pending_tokens += tokens

    # Seal the final, partially-filled batch (if any).
    if pending:
        grouped.append(pending)

    return grouped

async def list_models(
self,
api_key: str,
Expand Down Expand Up @@ -263,9 +307,17 @@ async def process_embeddings(
query_params = query_params or {}

all_embeddings = []
for i in range(0, len(payload["input"]), MAX_BATCH_SIZE):
total_usage = {"prompt_tokens": 0, "total_tokens": 0}

# Create token-aware batches
batches = self._create_token_aware_batches(payload["input"])

logger.info(f"Created {len(batches)} batches for {len(payload['input'])} inputs")

for i, batch_inputs in enumerate(batches):
logger.debug(f"Processing batch {i+1}/{len(batches)} with {len(batch_inputs)} inputs")
batch_payload = payload.copy()
batch_payload["input"] = payload["input"][i : i + MAX_BATCH_SIZE]
batch_payload["input"] = batch_inputs

async with (
aiohttp.ClientSession() as session,
Expand All @@ -286,12 +338,17 @@ async def process_embeddings(

response_json = await response.json()
all_embeddings.extend(response_json["data"])

# Accumulate usage statistics
if "usage" in response_json:
total_usage["prompt_tokens"] += response_json["usage"].get("prompt_tokens", 0)
total_usage["total_tokens"] += response_json["usage"].get("total_tokens", 0)

# Combine the results into a single response
final_response = {
"object": "list",
"data": all_embeddings,
"model": response_json["model"],
"usage": response_json["usage"],
"usage": total_usage,
}
return final_response
6 changes: 3 additions & 3 deletions tests/unit_tests/test_anthropic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

CURRENT_DIR = os.path.dirname(__file__)

with open(os.path.join(CURRENT_DIR, "docs", "anthropic", "list_models.json"), "r") as f:
with open(os.path.join(CURRENT_DIR, "assets", "anthropic", "list_models.json"), "r") as f:
MOCK_LIST_MODELS_RESPONSE_DATA = json.load(f)

with open(
os.path.join(CURRENT_DIR, "docs", "anthropic", "chat_completion_response_1.json"),
os.path.join(CURRENT_DIR, "assets", "anthropic", "chat_completion_response_1.json"),
"r",
) as f:
MOCK_CHAT_COMPLETION_RESPONSE_DATA = json.load(f)

with open(
os.path.join(
CURRENT_DIR, "docs", "anthropic", "chat_completion_streaming_response_1.json"
CURRENT_DIR, "assets", "anthropic", "chat_completion_streaming_response_1.json"
),
"r",
) as f:
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/test_google_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@

CURRENT_DIR = os.path.dirname(__file__)

with open(os.path.join(CURRENT_DIR, "docs", "google", "list_models.json"), "r") as f:
with open(os.path.join(CURRENT_DIR, "assets", "google", "list_models.json"), "r") as f:
MOCK_LIST_MODELS_RESPONSE_DATA = json.load(f)

with open(
os.path.join(CURRENT_DIR, "docs", "google", "chat_completion_response_1.json"), "r"
os.path.join(CURRENT_DIR, "assets", "google", "chat_completion_response_1.json"), "r"
) as f:
MOCK_CHAT_COMPLETION_RESPONSE_DATA = json.load(f)

with open(
os.path.join(
CURRENT_DIR, "docs", "google", "chat_completion_streaming_response_1.json"
CURRENT_DIR, "assets", "google", "chat_completion_streaming_response_1.json"
),
"r",
) as f:
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/test_openai_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,24 @@

CURRENT_DIR = os.path.dirname(__file__)

with open(os.path.join(CURRENT_DIR, "docs", "openai", "list_models.json"), "r") as f:
with open(os.path.join(CURRENT_DIR, "assets", "openai", "list_models.json"), "r") as f:
MOCK_LIST_MODELS_RESPONSE_DATA = json.load(f)

with open(
os.path.join(CURRENT_DIR, "docs", "openai", "chat_completion_response_1.json"), "r"
os.path.join(CURRENT_DIR, "assets", "openai", "chat_completion_response_1.json"), "r"
) as f:
MOCK_CHAT_COMPLETION_RESPONSE_DATA = json.load(f)

with open(
os.path.join(
CURRENT_DIR, "docs", "openai", "chat_completion_streaming_response_1.json"
CURRENT_DIR, "assets", "openai", "chat_completion_streaming_response_1.json"
),
"r",
) as f:
MOCK_CHAT_COMPLETION_STREAMING_RESPONSE_DATA = json.load(f)

with open(
os.path.join(CURRENT_DIR, "docs", "openai", "embeddings_response.json"), "r"
os.path.join(CURRENT_DIR, "assets", "openai", "embeddings_response.json"), "r"
) as f:
MOCK_EMBEDDINGS_RESPONSE_DATA = json.load(f)

Expand Down
Loading