Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions benchwise/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import asyncio
import random
import logging


class ModelAdapter(ABC):
Expand All @@ -24,7 +27,133 @@ def get_cost_estimate(self, input_tokens: int, output_tokens: int) -> float:
"""Estimate cost for given token counts."""
pass

async def batch_generate(
self,
prompts: List[str],
batch_size: Optional[int] = None,
max_concurrent: Optional[int] = None,
max_retries: Optional[int] = None,
base_delay: Optional[float] = None,
max_delay: Optional[float] = None,
**kwargs
) -> List[str]:
"""
Generate responses for a list of prompts with concurrent processing and rate limit protection.

Args:
prompts: List of prompts to process
batch_size: Number of prompts to process per batch. If None, uses config value or default: 50
max_concurrent: Maximum number of concurrent requests. If None, uses config value or default: 5
max_retries: Maximum number of retry attempts for rate limits. If None, uses config value or default: 3
base_delay: Base delay in seconds for exponential backoff. If None, uses config value or default: 1.0
max_delay: Maximum delay in seconds between retries. If None, uses config value or default: 60.0
**kwargs: Additional arguments passed to generate()

Returns:
List of responses in the same order as input prompts

Example:
>>> adapter = OpenAIAdapter("gpt-3.5-turbo", config={"max_concurrent": 10})
>>> prompts = ["Hello", "How are you?", "Tell me a joke"]
>>> responses = await adapter.batch_generate(prompts) # Uses max_concurrent=10 from config
>>> responses = await adapter.batch_generate(prompts, max_concurrent=5) # Overrides config
"""
if not prompts:
return []

logger = logging.getLogger("benchwise.models")

# Get values from config if not explicitly provided, with hardcoded defaults as fallback
batch_size = batch_size if batch_size is not None else self.config.get("batch_size", 50)
max_concurrent = max_concurrent if max_concurrent is not None else self.config.get("max_concurrent", 5)
max_retries = max_retries if max_retries is not None else self.config.get("max_retries", 3)
base_delay = base_delay if base_delay is not None else self.config.get("base_delay", 1.0)
max_delay = max_delay if max_delay is not None else self.config.get("max_delay", 60.0)

# Create semaphore for this call to limit concurrency
semaphore = asyncio.Semaphore(max_concurrent)

async def process_single_prompt(index: int, prompt: str) -> tuple[int, str]:
"""Process a single prompt with retry logic and rate limit handling."""
async with semaphore:
for attempt in range(max_retries + 1):
try:
# Call the generate method with a single prompt
results = await self.generate([prompt], **kwargs)
return (index, results[0])

except Exception as e:
error_str = str(e).lower()

# Check if this is a rate limit error
is_rate_limit = any([
"rate limit" in error_str,
"429" in error_str,
"too many requests" in error_str,
"quota" in error_str,
])

if is_rate_limit and attempt < max_retries:
# Calculate exponential backoff with jitter
delay = min(base_delay * (2 ** attempt), max_delay)
jitter = random.uniform(0, 0.1 * delay)
total_delay = delay + jitter

logger.warning(
f"Rate limit hit for prompt {index}, retrying after {total_delay:.2f}s "
f"(attempt {attempt + 1}/{max_retries})"
)
await asyncio.sleep(total_delay)
continue

# Non-rate-limit error or max retries exceeded
error_message = f"Error: {str(e)}"
if attempt == max_retries:
logger.error(f"Max retries exceeded for prompt {index}: {e}")
return (index, error_message)

# Should never reach here, but just in case
return (index, "Error: Unknown error occurred")

# Split prompts into batches
batches = []
for i in range(0, len(prompts), batch_size):
batch = prompts[i:i + batch_size]
batches.append(batch)

logger.info(
f"Processing {len(prompts)} prompts in {len(batches)} batch(es) "
f"with max {max_concurrent} concurrent requests"
)

# Process all prompts concurrently (up to max_concurrent at a time)
all_tasks = []
for batch_idx, batch in enumerate(batches):
batch_start_idx = batch_idx * batch_size
for i, prompt in enumerate(batch):
task = process_single_prompt(batch_start_idx + i, prompt)
all_tasks.append(task)

# Gather all results with exception handling
results_with_indices = await asyncio.gather(*all_tasks, return_exceptions=True)

# Handle any exceptions that weren't caught
processed_results = []
for item in results_with_indices:
if isinstance(item, Exception):
processed_results.append((len(processed_results), f"Error: {str(item)}"))
else:
processed_results.append(item)
Comment on lines +137 to +146
Copy link
Copy Markdown

@coderabbitai coderabbitai bot Dec 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Incorrect index assignment for uncaught exceptions.

Line 136 uses len(processed_results) as the index for uncaught exceptions, which assumes exceptions are processed in order and no previous exceptions occurred. However, if an exception occurs at index 2 and this is the first item processed, it would be assigned index 0 instead of 2, breaking order preservation.

Since process_single_prompt already returns (index, result) tuples and catches exceptions internally, exceptions reaching this point indicate a serious bug (e.g., in the exception handling itself). Consider logging these as critical errors instead of trying to recover:

 # Handle any exceptions that weren't caught
 processed_results = []
 for item in results_with_indices:
     if isinstance(item, Exception):
-        processed_results.append((len(processed_results), f"Error: {str(item)}"))
+        logger.critical(f"Unexpected exception escaped process_single_prompt: {item}")
+        # This should never happen, indicates a bug in process_single_prompt
+        raise item
     else:
         processed_results.append(item)

Alternatively, if you want to be defensive, you need to track which index failed. However, since asyncio.gather doesn't preserve position information when returning exceptions, you'd need to use asyncio.create_task with tracking or asyncio.as_completed.

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.8)

136-136: Use explicit conversion flag

Replace with conversion flag

(RUF010)

🤖 Prompt for AI Agents
In benchwise/models.py around lines 129 to 138, the code assigns indices for
exceptions using len(processed_results) which can mislabel positions; instead
iterate with enumerate over results_with_indices so you use the original
position (e.g., for idx, item in enumerate(results_with_indices):) and when item
is an Exception append (idx, f"Error: {str(item)}"); keep the else branch as-is
to append the returned (index, result) tuples; additionally log or raise a
critical error when an exception appears here since process_single_prompt should
have handled it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CodeRabbit If an exception occurred at index 2, there would already be two items in processed_results and its length would be 2, so the exception would be added at index 2, not at index 0.

I'm quoting you on this

Line 136 uses len(processed_results) as the index for uncaught exceptions, which assumes exceptions are processed in order and no previous exceptions occurred. However, if an exception occurs at index 2 and this is the first item processed, it would be assigned index 0 instead of 2, breaking order preservation.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yanurag-dev You're absolutely right, and I apologize for the confusion in my analysis.

Since asyncio.gather() preserves the order of results and we iterate through results_with_indices sequentially, len(processed_results) at each iteration does indeed equal the current index. So if an exception occurs at position 2, by the time we reach it in the iteration, processed_results would already contain 2 items (from indices 0 and 1), making len(processed_results) == 2, which is the correct index.

The code is actually working as intended. This exception handling block serves as defensive code for truly unexpected errors (like a bug in process_single_prompt itself or the semaphore failing), since process_single_prompt is designed to catch all exceptions internally and always return a (index, result_or_error) tuple.

Thank you for the correction!


# Sort by index to maintain original order and extract just the responses
processed_results.sort(key=lambda x: x[0])
responses = [result[1] for result in processed_results]

logger.info(f"Completed processing {len(responses)} prompts")

return responses


class OpenAIAdapter(ModelAdapter):
"""Adapter for OpenAI models."""

Expand Down
Loading