Add batch_generate method to ModelAdapter for concurrent prompt processing #16

yanurag-dev wants to merge 2 commits into Benchwise:develop
Conversation
Implemented the batch_generate method in the ModelAdapter class to handle generating responses for multiple prompts concurrently, with rate limit protection and retry logic. Added logging for better traceability of processing and error handling. Also included comprehensive unit tests for the new functionality, covering various scenarios such as empty prompts, maintaining order, handling rate limits, and testing with different adapters.
Walkthrough

Adds an async ModelAdapter.batch_generate method to process multiple prompts concurrently with optional batching, semaphore-based concurrency limits, per-prompt retry/backoff for rate limits, order-preserving aggregation, and logging. A new comprehensive test module validates behavior across adapters and edge cases.
Sequence Diagram(s)

sequenceDiagram
actor User
participant ModelAdapter
participant Semaphore
participant LLM_API
User->>ModelAdapter: batch_generate(prompts, batch_size?, max_concurrent?)
activate ModelAdapter
Note over ModelAdapter: Split prompts into batches (if batch_size)
loop For each batch
loop For each prompt in batch (up to max_concurrent)
ModelAdapter->>Semaphore: acquire()
activate Semaphore
Semaphore-->>ModelAdapter: permit
deactivate Semaphore
ModelAdapter->>LLM_API: generate(prompt)
activate LLM_API
alt Success
LLM_API-->>ModelAdapter: result
else Rate limit (429)
LLM_API-->>ModelAdapter: rate-limit error
Note right of ModelAdapter: exponential backoff + jitter\nretry up to max_retries
ModelAdapter->>LLM_API: generate(prompt) [retry]
LLM_API-->>ModelAdapter: result or final error
else Non-retryable error
LLM_API-->>ModelAdapter: error
end
deactivate LLM_API
ModelAdapter->>Semaphore: release()
end
end
Note over ModelAdapter: Aggregate results preserving original order
ModelAdapter-->>User: List[str] (responses / error messages)
deactivate ModelAdapter
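The flow in the diagram can be condensed into a short, self-contained sketch. This is an illustration of the described behavior (semaphore-bounded concurrency, exponential backoff with jitter, order-preserving aggregation), not the PR's actual code; `generate` is a placeholder for the adapter's per-prompt call:

```python
import asyncio
import random

async def batch_generate_sketch(prompts, generate, max_concurrent=5,
                                max_retries=3, base_delay=1.0, max_delay=60.0):
    """Fan out prompts under a concurrency cap, retrying rate-limit errors."""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def one(index, prompt):
        for attempt in range(max_retries + 1):
            try:
                async with semaphore:  # cap simultaneous in-flight requests
                    return (index, await generate(prompt))
            except Exception as e:
                msg = str(e).lower()
                retryable = "429" in msg or "rate limit" in msg
                if retryable and attempt < max_retries:
                    # exponential backoff plus up to 10% jitter
                    delay = min(base_delay * 2 ** attempt, max_delay)
                    await asyncio.sleep(delay + random.uniform(0, 0.1 * delay))
                    continue
                return (index, f"Error: {e}")

    pairs = await asyncio.gather(*(one(i, p) for i, p in enumerate(prompts)))
    # sort by original index so results line up with the input prompts
    return [text for _, text in sorted(pairs)]
```

Run against a stub coroutine such as `async def fake(p): return p.upper()`, responses come back in the original prompt order regardless of completion order.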
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
🧹 Nitpick comments (4)
benchwise/models.py (3)
68-108: Semaphore held during retry delays reduces effective concurrency.

The retry loop (lines 71-99) executes inside the `async with semaphore:` block, meaning a semaphore slot is occupied during `await asyncio.sleep(total_delay)` on line 98. This reduces the number of active concurrent requests when retries are happening. Consider releasing and reacquiring the semaphore around the sleep:
```diff
 async def process_single_prompt(index: int, prompt: str) -> tuple[int, str]:
     """Process a single prompt with retry logic and rate limit handling."""
-    async with semaphore:
-        for attempt in range(max_retries + 1):
-            try:
-                # Call the generate method with a single prompt
-                results = await self.generate([prompt], **kwargs)
-                return (index, results[0])
-
-            except Exception as e:
-                error_str = str(e).lower()
-
-                # Check if this is a rate limit error
-                is_rate_limit = any([
-                    "rate limit" in error_str,
-                    "429" in error_str,
-                    "too many requests" in error_str,
-                    "quota" in error_str,
-                ])
-
-                if is_rate_limit and attempt < max_retries:
-                    # Calculate exponential backoff with jitter
-                    delay = min(base_delay * (2 ** attempt), max_delay)
-                    jitter = random.uniform(0, 0.1 * delay)
-                    total_delay = delay + jitter
-
-                    logger.warning(
-                        f"Rate limit hit for prompt {index}, retrying after {total_delay:.2f}s "
-                        f"(attempt {attempt + 1}/{max_retries})"
-                    )
-                    await asyncio.sleep(total_delay)
-                    continue
-
-                # Non-rate-limit error or max retries exceeded
-                error_message = f"Error: {str(e)}"
-                if attempt == max_retries:
-                    logger.error(f"Max retries exceeded for prompt {index}: {e}")
-                    return (index, error_message)
-
-        # Should never reach here, but just in case
-        return (index, "Error: Unknown error occurred")
+    for attempt in range(max_retries + 1):
+        try:
+            async with semaphore:
+                # Call the generate method with a single prompt
+                results = await self.generate([prompt], **kwargs)
+                return (index, results[0])
+
+        except Exception as e:
+            error_str = str(e).lower()
+
+            # Check if this is a rate limit error
+            is_rate_limit = any([
+                "rate limit" in error_str,
+                "429" in error_str,
+                "too many requests" in error_str,
+                "quota" in error_str,
+            ])
+
+            if is_rate_limit and attempt < max_retries:
+                # Calculate exponential backoff with jitter
+                delay = min(base_delay * (2 ** attempt), max_delay)
+                jitter = random.uniform(0, 0.1 * delay)
+                total_delay = delay + jitter
+
+                logger.warning(
+                    f"Rate limit hit for prompt {index}, retrying after {total_delay:.2f}s "
+                    f"(attempt {attempt + 1}/{max_retries})"
+                )
+                await asyncio.sleep(total_delay)
+                continue
+
+            # Non-rate-limit error or max retries exceeded
+            error_message = f"Error: {str(e)}"
+            if attempt == max_retries:
+                logger.error(f"Max retries exceeded for prompt {index}: {e}")
+                return (index, error_message)
+
+    # Should never reach here, but just in case
+    return (index, "Error: Unknown error occurred")
```
104-104: Prefer `logging.exception` for automatic traceback inclusion.

When logging errors within an exception handler, `logging.exception` automatically includes the traceback, which aids debugging. Apply this diff:

```diff
-                logger.error(f"Max retries exceeded for prompt {index}: {e}")
+                logger.exception(f"Max retries exceeded for prompt {index}: {e}")
```
110-128: `batch_size` parameter has no effective impact on processing.

The `batch_size` is used to split prompts into batches (lines 111-114) and iterate through them (lines 123-127), but all tasks are created simultaneously and gathered at once (line 130). This means `batch_size` doesn't actually control how many prompts are processed concurrently; only `max_concurrent` does.

If the intent is to process batches sequentially (to better control memory or API quotas), consider awaiting each batch before starting the next. Otherwise, consider removing the `batch_size` parameter to simplify the API.

Option 1: Process batches sequentially

```diff
-        # Process all prompts concurrently (up to max_concurrent at a time)
-        all_tasks = []
+        # Process batches sequentially
+        all_results = []
         for batch_idx, batch in enumerate(batches):
             batch_start_idx = batch_idx * batch_size
+            batch_tasks = []
             for i, prompt in enumerate(batch):
                 task = process_single_prompt(batch_start_idx + i, prompt)
-                all_tasks.append(task)
+                batch_tasks.append(task)
+            # Await this batch before moving to the next
+            batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
+            all_results.extend(batch_results)

-        # Gather all results with exception handling
-        results_with_indices = await asyncio.gather(*all_tasks, return_exceptions=True)
+        results_with_indices = all_results
```

Option 2: Remove `batch_size` parameter (simpler)

If sequential batching isn't needed, remove the parameter and simplify the code.
tests/test_batch_generate.py (1)

239-259: LGTM: Validates large-scale processing.

The test effectively verifies that batch_generate can handle larger datasets while maintaining order. Consider removing or converting the print statement at line 259 to use pytest's output capturing or logging:

```diff
-    # Should complete reasonably quickly due to concurrency
-    print(f"Processed {len(prompts)} prompts in {end_time - start_time:.2f}s")
```

If timing information is valuable for CI monitoring, use pytest markers or logging instead.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
benchwise/models.py (2 hunks)
tests/test_batch_generate.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/test_batch_generate.py (1)
benchwise/models.py (12)
ModelAdapter (8-146), OpenAIAdapter (149-211), AnthropicAdapter (214-276), MockAdapter (372-391), get_model_adapter (394-417), batch_generate (30-146), generate (16-18), generate (173-198), generate (238-263), generate (296-313), generate (342-361), generate (378-383)
benchwise/models.py (2)
benchwise/datasets.py (1)
prompts (49-62)
tests/conftest.py (1)
generate (180-186)
🪛 Ruff (0.14.8)
tests/test_batch_generate.py
119-119: Unused function argument: kwargs
(ARG001)
125-125: Create your own exception
(TRY002)
125-125: Avoid specifying long messages outside the exception class
(TRY003)
142-142: Unused function argument: prompts
(ARG001)
142-142: Unused function argument: kwargs
(ARG001)
143-143: Create your own exception
(TRY002)
143-143: Avoid specifying long messages outside the exception class
(TRY003)
161-161: Unused function argument: prompts
(ARG001)
161-161: Unused function argument: kwargs
(ARG001)
168-168: Create your own exception
(TRY002)
168-168: Avoid specifying long messages outside the exception class
(TRY003)
193-193: Unused function argument: prompts
(ARG001)
193-193: Unused function argument: kwargs
(ARG001)
196-196: Create your own exception
(TRY002)
196-196: Avoid specifying long messages outside the exception class
(TRY003)
290-290: Unused function argument: prompts
(ARG001)
290-290: Unused function argument: kwargs
(ARG001)
293-293: Avoid specifying long messages outside the exception class
(TRY003)
312-312: Unused function argument: prompts
(ARG001)
312-312: Unused function argument: kwargs
(ARG001)
316-316: Create your own exception
(TRY002)
316-316: Avoid specifying long messages outside the exception class
(TRY003)
335-335: Unused function argument: prompts
(ARG001)
335-335: Unused function argument: kwargs
(ARG001)
339-339: Create your own exception
(TRY002)
339-339: Avoid specifying long messages outside the exception class
(TRY003)
benchwise/models.py
77-77: Do not catch blind exception: Exception
(BLE001)
91-91: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
102-102: Use explicit conversion flag
Replace with conversion flag
(RUF010)
104-104: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
136-136: Use explicit conversion flag
Replace with conversion flag
(RUF010)
🔇 Additional comments (18)
benchwise/models.py (5)

3-5: LGTM: Necessary imports for async concurrency control.

The new imports support the batch processing functionality: `asyncio` for concurrency control via `Semaphore`, `random` for backoff jitter, and `logging` for traceability.

60-61: LGTM: Appropriate early return.

The empty prompts check prevents unnecessary processing and returns the expected empty list.

63-66: LGTM: Proper concurrency control setup.

Creating a semaphore per call correctly limits the concurrency for this specific batch_generate invocation.

116-119: LGTM: Helpful logging for processing visibility.

The info-level log provides useful context for monitoring batch operations.

140-146: LGTM: Correct order restoration and response extraction.

The sort-by-index approach correctly restores the original prompt order, assuming indices are assigned correctly upstream.
tests/test_batch_generate.py (13)

18-28: LGTM: Solid basic functionality test.

The test verifies correct response count and content matching for basic batch generation.

30-43: LGTM: Good edge case coverage.

These tests appropriately verify behavior with empty and single-element prompt lists.

45-55: LGTM: Validates order preservation under concurrency.

This test correctly verifies that concurrent processing maintains the original prompt order in the results.

57-70: LGTM: Tests different batch size configurations.

The test validates that varying `batch_size` doesn't break functionality. However, it doesn't verify that batch sizing actually controls processing behavior (e.g., memory usage or sequential batch execution), which aligns with the implementation issue noted in benchwise/models.py.

72-111: LGTM: Effectively validates concurrency limits.

The concurrent execution tracking correctly verifies that `max_concurrent` is respected.

113-136: LGTM: Validates rate limit retry mechanism.

The test correctly verifies that rate-limited requests are retried and eventually succeed.

138-153: LGTM: Validates max retry exhaustion.

The test correctly verifies that persistent rate limit errors eventually result in error responses after exhausting retries.

155-185: LGTM: Validates partial failure handling.

The test correctly verifies that failures in some prompts don't prevent successful processing of others, and that failures appear in the correct positions.

187-213: LGTM: Validates exponential backoff behavior.

The test correctly verifies that retry delays follow exponential backoff. The timing assertion at line 213 accounts for the expected minimum delay but could be sensitive to system load or slow CI environments. If this test becomes flaky in CI, consider either:

- Adding a larger margin (e.g., `>= 0.25` instead of `>= 0.3`)
- Mocking `asyncio.sleep` to verify call arguments instead of measuring wall-clock time
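The second option can be sketched as follows. The retry helper here is a hypothetical stand-in, not the code under review; it only shows how patching `asyncio.sleep` with an `AsyncMock` lets a test assert on the backoff delay instead of waiting for it:

```python
import asyncio
from unittest.mock import AsyncMock, patch

async def retry_twice(fn, base_delay=1.0):
    """Hypothetical helper: one retry with exponential backoff on failure."""
    for attempt in range(2):
        try:
            return await fn()
        except RuntimeError:
            if attempt == 0:
                await asyncio.sleep(base_delay * (2 ** attempt))
    return None

def test_backoff_without_wall_clock():
    calls = {"n": 0}

    async def flaky():
        calls["n"] += 1
        if calls["n"] == 1:
            raise RuntimeError("429 Too Many Requests")
        return "ok"

    # Replace asyncio.sleep so the test asserts the delay value, not real time.
    with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
        result = asyncio.run(retry_twice(flaky))

    assert result == "ok"
    mock_sleep.assert_awaited_once_with(1.0)
```

This keeps the test deterministic on slow CI runners, since no wall-clock delay is involved.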
215-237: LGTM: Validates kwargs propagation.

The test correctly verifies that additional keyword arguments are passed through to the underlying `generate` method.

261-350: LGTM: Comprehensive error handling and adapter compatibility tests.

The test suite thoroughly validates:

- Cross-adapter compatibility
- Factory function integration
- Correct retry behavior for different error types (rate limit vs non-rate limit)
- Detection of various rate limit error formats (429, quota exceeded)

353-404: LGTM: Integration tests validate real adapter behavior.

The integration tests appropriately mock external API clients and verify that `batch_generate` correctly invokes the underlying adapter's `generate` method the expected number of times.

407-413: LGTM: Standard test runner.

The `__main__` block allows convenient direct execution of the test file.
```python
        # Gather all results with exception handling
        results_with_indices = await asyncio.gather(*all_tasks, return_exceptions=True)

        # Handle any exceptions that weren't caught
        processed_results = []
        for item in results_with_indices:
            if isinstance(item, Exception):
                processed_results.append((len(processed_results), f"Error: {str(item)}"))
            else:
                processed_results.append(item)
```
Incorrect index assignment for uncaught exceptions.
Line 136 uses len(processed_results) as the index for uncaught exceptions, which assumes exceptions are processed in order and no previous exceptions occurred. However, if an exception occurs at index 2 and this is the first item processed, it would be assigned index 0 instead of 2, breaking order preservation.
Since process_single_prompt already returns (index, result) tuples and catches exceptions internally, exceptions reaching this point indicate a serious bug (e.g., in the exception handling itself). Consider logging these as critical errors instead of trying to recover:
```diff
 # Handle any exceptions that weren't caught
 processed_results = []
 for item in results_with_indices:
     if isinstance(item, Exception):
-        processed_results.append((len(processed_results), f"Error: {str(item)}"))
+        logger.critical(f"Unexpected exception escaped process_single_prompt: {item}")
+        # This should never happen, indicates a bug in process_single_prompt
+        raise item
     else:
         processed_results.append(item)
```

Alternatively, if you want to be defensive, you need to track which index failed. However, since asyncio.gather doesn't preserve position information when returning exceptions, you'd need to use asyncio.create_task with tracking or asyncio.as_completed.
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.14.8)
136-136: Use explicit conversion flag
Replace with conversion flag
(RUF010)
🤖 Prompt for AI Agents
In benchwise/models.py around lines 129 to 138, the code assigns indices for
exceptions using len(processed_results) which can mislabel positions; instead
iterate with enumerate over results_with_indices so you use the original
position (e.g., for idx, item in enumerate(results_with_indices):) and when item
is an Exception append (idx, f"Error: {str(item)}"); keep the else branch as-is
to append the returned (index, result) tuples; additionally log or raise a
critical error when an exception appears here since process_single_prompt should
have handled it.
@CodeRabbit If the exception occurred at index 2, there would already be two items in processed_results and its length would be 2, so the exception would be added at index 2, not index 0.
I'm quoting you on this
Line 136 uses len(processed_results) as the index for uncaught exceptions, which assumes exceptions are processed in order and no previous exceptions occurred. However, if an exception occurs at index 2 and this is the first item processed, it would be assigned index 0 instead of 2, breaking order preservation.
@yanurag-dev You're absolutely right, and I apologize for the confusion in my analysis.
Since asyncio.gather() preserves the order of results and we iterate through results_with_indices sequentially, len(processed_results) at each iteration does indeed equal the current index. So if an exception occurs at position 2, by the time we reach it in the iteration, processed_results would already contain 2 items (from indices 0 and 1), making len(processed_results) == 2, which is the correct index.
The code is actually working as intended. This exception handling block serves as defensive code for truly unexpected errors (like a bug in process_single_prompt itself or the semaphore failing), since process_single_prompt is designed to catch all exceptions internally and always return a (index, result_or_error) tuple.
Thank you for the correction!
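The ordering point discussed above can be verified directly: `asyncio.gather` returns results in submission order even when tasks complete out of order or raise (with `return_exceptions=True`), which is why a sequential counter over the gathered list matches the original index. A minimal, self-contained check:

```python
import asyncio

async def work(i):
    # Later submissions finish first; index 2 fails outright.
    await asyncio.sleep(0.01 * (3 - i))
    if i == 2:
        raise ValueError("boom")
    return i

async def main():
    return await asyncio.gather(*(work(i) for i in range(3)),
                                return_exceptions=True)

results = asyncio.run(main())
# Results sit at their submission positions despite reversed completion order.
assert results[0] == 0 and results[1] == 1
assert isinstance(results[2], ValueError)
```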
Actionable comments posted: 0
🧹 Nitpick comments (3)
benchwise/models.py (3)
109-116: Unreachable code and logging improvement.

Two minor issues:

- Lines 115-116 are unreachable: the `for` loop always returns at either line 83 (success) or line 113 (error).
- Per static analysis (TRY400), `logging.exception` at line 112 would include the traceback for better debugging.

```diff
             # Non-rate-limit error or max retries exceeded
-            error_message = f"Error: {str(e)}"
+            error_message = f"Error: {e!s}"
             if attempt == max_retries:
-                logger.error(f"Max retries exceeded for prompt {index}: {e}")
+                logger.exception(f"Max retries exceeded for prompt {index}")
                 return (index, error_message)
-
-        # Should never reach here, but just in case
-        return (index, "Error: Unknown error occurred")
```
118-138: Clarify the role of `batch_size` vs `max_concurrent`.

The current implementation queues all prompts as tasks upfront (lines 130-135) and gathers them together (line 138). The semaphore controls concurrency, not the `batch_size`. The batching loop (lines 119-122) only affects index calculation but doesn't create pauses between batches. This means:

- `max_concurrent` limits how many prompts run simultaneously ✓
- `batch_size` currently has no effect on execution flow

If `batch_size` is intended for future use (e.g., inter-batch delays, per-batch logging), that's fine. Otherwise, consider removing it to avoid user confusion, or implement actual batched execution with optional delays between batches. If batch_size is intentionally a no-op for now, consider adding a docstring note:

```diff
-    batch_size: Number of prompts to process per batch. If None, uses config value or default: 50
+    batch_size: Number of prompts per batch for logging/organization purposes.
+        Actual concurrency is controlled by max_concurrent. If None, uses config value or default: 50
```

Alternatively, simplify by removing unused batching:

```diff
-        # Split prompts into batches
-        batches = []
-        for i in range(0, len(prompts), batch_size):
-            batch = prompts[i:i + batch_size]
-            batches.append(batch)
-
         logger.info(
-            f"Processing {len(prompts)} prompts in {len(batches)} batch(es) "
-            f"with max {max_concurrent} concurrent requests"
+            f"Processing {len(prompts)} prompts with max {max_concurrent} concurrent requests"
         )

         # Process all prompts concurrently (up to max_concurrent at a time)
         all_tasks = []
-        for batch_idx, batch in enumerate(batches):
-            batch_start_idx = batch_idx * batch_size
-            for i, prompt in enumerate(batch):
-                task = process_single_prompt(batch_start_idx + i, prompt)
-                all_tasks.append(task)
+        for i, prompt in enumerate(prompts):
+            all_tasks.append(process_single_prompt(i, prompt))
```
88-94: Rate limit detection relies on string matching.

The rate limit detection checks for keywords in the error string. This is pragmatic but may miss provider-specific error formats or break if error messages change. Consider also checking for exception types from specific provider SDKs when available (e.g., `openai.RateLimitError`, `anthropic.RateLimitError`), or document that subclasses can override this detection logic.
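A sketch of the combined approach: typed checks first, string matching as a fallback. `RateLimitError` here is a local stand-in for whatever SDK exception classes are available, not a specific provider's class:

```python
class RateLimitError(Exception):
    """Stand-in for provider SDK rate-limit exceptions."""

RETRYABLE_TYPES = (RateLimitError,)
RETRYABLE_MARKERS = ("rate limit", "429", "too many requests", "quota")

def is_rate_limit(exc: Exception) -> bool:
    # Prefer typed detection when the SDK exposes a dedicated exception class.
    if isinstance(exc, RETRYABLE_TYPES):
        return True
    # Fall back to keyword matching on the error message.
    msg = str(exc).lower()
    return any(marker in msg for marker in RETRYABLE_MARKERS)
```

Subclasses could then extend `RETRYABLE_TYPES` (or override the whole predicate) to add provider-specific exceptions without touching the retry loop.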
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
benchwise/models.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.14.8)
benchwise/models.py
85-85: Do not catch blind exception: Exception
(BLE001)
99-99: Standard pseudo-random generators are not suitable for cryptographic purposes
(S311)
110-110: Use explicit conversion flag
Replace with conversion flag
(RUF010)
112-112: Use logging.exception instead of logging.error
Replace with exception
(TRY400)
144-144: Use explicit conversion flag
Replace with conversion flag
(RUF010)
🔇 Additional comments (1)
benchwise/models.py (1)

3-5: LGTM!

The imports are appropriate for the async batch processing, concurrency control, and logging functionality.
Add batch_generate method to ModelAdapter for concurrent prompt processing
Implemented the batch_generate method in the ModelAdapter class to handle generating responses for multiple prompts concurrently, with rate limit protection and retry logic. Added logging for better traceability of processing and error handling.
Also included comprehensive unit tests for the new functionality, covering various scenarios such as empty prompts, maintaining order, handling rate limits, and testing with different adapters.
Summary by CodeRabbit