
Add batch_generate method to ModelAdapter for concurrent prompt processing #16

Open

yanurag-dev wants to merge 2 commits into Benchwise:develop from yanurag-dev:feat/Add-batch-processing

Conversation


@yanurag-dev (Collaborator) commented Dec 13, 2025

Add batch_generate method to ModelAdapter for concurrent prompt processing

Implemented the batch_generate method in the ModelAdapter class to generate responses for multiple prompts concurrently, with rate-limit protection and retry logic. Added logging for better traceability of processing and error handling.

Also included comprehensive unit tests for the new functionality, covering various scenarios such as empty prompts, maintaining order, handling rate limits, and testing with different adapters.

Summary by CodeRabbit

  • New Features
    • Added batch generation for processing multiple prompts with controlled concurrency, ordered results, per-prompt error handling, and automatic retries with exponential backoff and jitter.
  • Tests
    • Added comprehensive test coverage validating batching, order preservation, concurrency limits, retry/backoff behavior, error handling, and integration with multiple adapters.



coderabbitai bot commented Dec 13, 2025

Walkthrough

Adds an async ModelAdapter.batch_generate method to process multiple prompts concurrently with optional batching, semaphore-based concurrency limits, per-prompt retry/backoff for rate limits, order-preserving aggregation, and logging. A new comprehensive test module validates behavior across adapters and edge cases.

Changes

Cohort / File(s): Summary

  • Batch Generation Implementation (benchwise/models.py):
    Adds async def batch_generate(...) to ModelAdapter. Implements batching, optional max_concurrent enforced with asyncio.Semaphore, per-prompt retry on rate-limit errors with exponential backoff + jitter, preserves prompt order, and returns per-prompt responses or error messages.
  • Test Suite (tests/test_batch_generate.py):
    Adds tests covering basic and edge cases: empty/single prompts, ordering, batch_size/max_concurrent variations, retry/backoff behavior (including 429/quota cases), non-retryable errors, propagation of kwargs, large datasets, and mocked integration tests for OpenAI/Anthropic adapters.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant ModelAdapter
    participant Semaphore
    participant LLM_API

    User->>ModelAdapter: batch_generate(prompts, batch_size?, max_concurrent?)
    activate ModelAdapter
    Note over ModelAdapter: Split prompts into batches (if batch_size)

    loop For each batch
        loop For each prompt in batch (up to max_concurrent)
            ModelAdapter->>Semaphore: acquire()
            activate Semaphore
            Semaphore-->>ModelAdapter: permit
            deactivate Semaphore

            ModelAdapter->>LLM_API: generate(prompt)
            activate LLM_API

            alt Success
                LLM_API-->>ModelAdapter: result
            else Rate limit (429)
                LLM_API-->>ModelAdapter: rate-limit error
                Note right of ModelAdapter: exponential backoff + jitter<br/>retry up to max_retries
                ModelAdapter->>LLM_API: generate(prompt) [retry]
                LLM_API-->>ModelAdapter: result or final error
            else Non-retryable error
                LLM_API-->>ModelAdapter: error
            end

            deactivate LLM_API
            ModelAdapter->>Semaphore: release()
        end
    end

    Note over ModelAdapter: Aggregate results preserving original order
    ModelAdapter-->>User: List[str] (responses / error messages)
    deactivate ModelAdapter
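The flow above can be reduced to a minimal, self-contained sketch. The names and the fake generate call are illustrative only, not the PR's actual code: a semaphore caps in-flight calls, each result carries its original index, and a final sort restores input order.

```python
import asyncio

async def batch_generate_sketch(prompts, max_concurrent=2):
    """Process prompts concurrently, capped by a semaphore, preserving order."""
    semaphore = asyncio.Semaphore(max_concurrent)

    async def generate_one(index, prompt):
        # Stand-in for the real LLM call; returns (index, response)
        async with semaphore:
            await asyncio.sleep(0)  # simulate I/O
            return (index, f"response to {prompt}")

    tasks = [generate_one(i, p) for i, p in enumerate(prompts)]
    results = await asyncio.gather(*tasks)
    # Sort by index so output order matches input order
    results.sort(key=lambda pair: pair[0])
    return [response for _, response in results]

responses = asyncio.run(batch_generate_sketch(["a", "b", "c"], max_concurrent=2))
```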

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • Focus areas:
    • Verify correct acquire/release of asyncio.Semaphore and no deadlocks.
    • Confirm retry classification (429/quota) vs non-retryable errors.
    • Validate exponential backoff math, jitter bounds, and adherence to max_retries, base_delay, max_delay.
    • Ensure preserved input ordering when collecting concurrent results.
    • Review tests for realistic mocking of external clients and timing assertions.
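The backoff math in the third focus area can be checked in isolation. This helper mirrors the computation shown later in the review diffs (min(base_delay * 2**attempt, max_delay) plus up to 10% random jitter); the function name is illustrative:

```python
import random

def backoff_delay(attempt, base_delay=1.0, max_delay=60.0):
    """Exponential backoff capped at max_delay, with up to 10% random jitter."""
    delay = min(base_delay * (2 ** attempt), max_delay)
    jitter = random.uniform(0, 0.1 * delay)
    return delay + jitter

# Delays grow 1, 2, 4, 8, ... seconds (plus jitter) until the cap is hit.
delays = [backoff_delay(a) for a in range(8)]
```

The jitter bound (at most 10% of the base delay) gives a simple invariant to assert against: every delay lies between min(2**attempt, max_delay) and 1.1 times that value.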

Possibly related issues

  • Issue #7: Directly related — implements ModelAdapter.batch_generate with batching and semaphore-based concurrency plus retry/backoff, matching the issue description.

Poem

🐇 I hop through prompts both near and far,

Batches align like rows of star,
When 429s say "hold on tight,"
I jitter back and retry the fight,
Results returned in tidy rows — hooray, batched delight! ✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 70.97%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title check: ✅ Passed. The title clearly and specifically describes the main change (adding a batch_generate method to ModelAdapter for concurrent prompt processing), which matches the primary addition in the changeset.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai (bot) left a comment


Actionable comments posted: 2

🧹 Nitpick comments (4)
benchwise/models.py (3)

68-108: Semaphore held during retry delays reduces effective concurrency.

The retry loop (lines 71-99) executes inside the async with semaphore: block, meaning a semaphore slot is occupied during await asyncio.sleep(total_delay) on line 98. This reduces the number of active concurrent requests when retries are happening.

Consider releasing and reacquiring the semaphore around the sleep:

 async def process_single_prompt(index: int, prompt: str) -> tuple[int, str]:
     """Process a single prompt with retry logic and rate limit handling."""
-    async with semaphore:
-        for attempt in range(max_retries + 1):
-            try:
-                # Call the generate method with a single prompt
-                results = await self.generate([prompt], **kwargs)
-                return (index, results[0])
-            
-            except Exception as e:
-                error_str = str(e).lower()
-                
-                # Check if this is a rate limit error
-                is_rate_limit = any([
-                    "rate limit" in error_str,
-                    "429" in error_str,
-                    "too many requests" in error_str,
-                    "quota" in error_str,
-                ])
-                
-                if is_rate_limit and attempt < max_retries:
-                    # Calculate exponential backoff with jitter
-                    delay = min(base_delay * (2 ** attempt), max_delay)
-                    jitter = random.uniform(0, 0.1 * delay)
-                    total_delay = delay + jitter
-                    
-                    logger.warning(
-                        f"Rate limit hit for prompt {index}, retrying after {total_delay:.2f}s "
-                        f"(attempt {attempt + 1}/{max_retries})"
-                    )
-                    await asyncio.sleep(total_delay)
-                    continue
-                
-                # Non-rate-limit error or max retries exceeded
-                error_message = f"Error: {str(e)}"
-                if attempt == max_retries:
-                    logger.error(f"Max retries exceeded for prompt {index}: {e}")
-                return (index, error_message)
-    
-    # Should never reach here, but just in case
-    return (index, "Error: Unknown error occurred")
+    for attempt in range(max_retries + 1):
+        try:
+            async with semaphore:
+                # Call the generate method with a single prompt
+                results = await self.generate([prompt], **kwargs)
+            return (index, results[0])
+        
+        except Exception as e:
+            error_str = str(e).lower()
+            
+            # Check if this is a rate limit error
+            is_rate_limit = any([
+                "rate limit" in error_str,
+                "429" in error_str,
+                "too many requests" in error_str,
+                "quota" in error_str,
+            ])
+            
+            if is_rate_limit and attempt < max_retries:
+                # Calculate exponential backoff with jitter
+                delay = min(base_delay * (2 ** attempt), max_delay)
+                jitter = random.uniform(0, 0.1 * delay)
+                total_delay = delay + jitter
+                
+                logger.warning(
+                    f"Rate limit hit for prompt {index}, retrying after {total_delay:.2f}s "
+                    f"(attempt {attempt + 1}/{max_retries})"
+                )
+                await asyncio.sleep(total_delay)
+                continue
+            
+            # Non-rate-limit error or max retries exceeded
+            error_message = f"Error: {str(e)}"
+            if attempt == max_retries:
+                logger.error(f"Max retries exceeded for prompt {index}: {e}")
+            return (index, error_message)
+    
+    # Should never reach here, but just in case
+    return (index, "Error: Unknown error occurred")

104-104: Prefer logging.exception for automatic traceback inclusion.

When logging errors within an exception handler, logging.exception automatically includes the traceback, which aids debugging.

Apply this diff:

-                logger.error(f"Max retries exceeded for prompt {index}: {e}")
+                logger.exception(f"Max retries exceeded for prompt {index}: {e}")

110-128: batch_size parameter has no effective impact on processing.

The batch_size is used to split prompts into batches (lines 111-114) and iterate through them (lines 123-127), but all tasks are created simultaneously and gathered at once (line 130). This means batch_size doesn't actually control how many prompts are processed concurrently—only max_concurrent does.

If the intent is to process batches sequentially (to better control memory or API quotas), consider awaiting each batch before starting the next. Otherwise, consider removing the batch_size parameter to simplify the API.

Option 1: Process batches sequentially

-        # Process all prompts concurrently (up to max_concurrent at a time)
-        all_tasks = []
+        # Process batches sequentially
+        all_results = []
         for batch_idx, batch in enumerate(batches):
             batch_start_idx = batch_idx * batch_size
+            batch_tasks = []
             for i, prompt in enumerate(batch):
                 task = process_single_prompt(batch_start_idx + i, prompt)
-                all_tasks.append(task)
+                batch_tasks.append(task)
+            # Await this batch before moving to the next
+            batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
+            all_results.extend(batch_results)
         
-        # Gather all results with exception handling
-        results_with_indices = await asyncio.gather(*all_tasks, return_exceptions=True)
+        results_with_indices = all_results

Option 2: Remove batch_size parameter (simpler)

If sequential batching isn't needed, remove the parameter and simplify the code.

tests/test_batch_generate.py (1)

239-259: LGTM: Validates large-scale processing.

The test effectively verifies that batch_generate can handle larger datasets while maintaining order.

Consider removing or converting the print statement at line 259 to use pytest's output capturing or logging:

-        # Should complete reasonably quickly due to concurrency
-        print(f"Processed {len(prompts)} prompts in {end_time - start_time:.2f}s")

If timing information is valuable for CI monitoring, use pytest markers or logging instead.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a2569de and 3b7439a.

📒 Files selected for processing (2)
  • benchwise/models.py (2 hunks)
  • tests/test_batch_generate.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/test_batch_generate.py (1)
benchwise/models.py (12)
  • ModelAdapter (8-146)
  • OpenAIAdapter (149-211)
  • AnthropicAdapter (214-276)
  • MockAdapter (372-391)
  • get_model_adapter (394-417)
  • batch_generate (30-146)
  • generate (16-18)
  • generate (173-198)
  • generate (238-263)
  • generate (296-313)
  • generate (342-361)
  • generate (378-383)
benchwise/models.py (2)
benchwise/datasets.py (1)
  • prompts (49-62)
tests/conftest.py (1)
  • generate (180-186)
🪛 Ruff (0.14.8)
tests/test_batch_generate.py

  • 119: Unused function argument: kwargs (ARG001)
  • 125: Create your own exception (TRY002); Avoid specifying long messages outside the exception class (TRY003)
  • 142: Unused function arguments: prompts, kwargs (ARG001)
  • 143: Create your own exception (TRY002); Avoid specifying long messages outside the exception class (TRY003)
  • 161: Unused function arguments: prompts, kwargs (ARG001)
  • 168: Create your own exception (TRY002); Avoid specifying long messages outside the exception class (TRY003)
  • 193: Unused function arguments: prompts, kwargs (ARG001)
  • 196: Create your own exception (TRY002); Avoid specifying long messages outside the exception class (TRY003)
  • 290: Unused function arguments: prompts, kwargs (ARG001)
  • 293: Avoid specifying long messages outside the exception class (TRY003)
  • 312: Unused function arguments: prompts, kwargs (ARG001)
  • 316: Create your own exception (TRY002); Avoid specifying long messages outside the exception class (TRY003)
  • 335: Unused function arguments: prompts, kwargs (ARG001)
  • 339: Create your own exception (TRY002); Avoid specifying long messages outside the exception class (TRY003)

benchwise/models.py

  • 77: Do not catch blind exception: Exception (BLE001)
  • 91: Standard pseudo-random generators are not suitable for cryptographic purposes (S311)
  • 102: Use explicit conversion flag; replace with conversion flag (RUF010)
  • 104: Use logging.exception instead of logging.error; replace with exception (TRY400)
  • 136: Use explicit conversion flag; replace with conversion flag (RUF010)

🔇 Additional comments (18)
benchwise/models.py (5)

3-5: LGTM: Necessary imports for async concurrency control.

The new imports support the batch processing functionality: asyncio for concurrency control via Semaphore, random for backoff jitter, and logging for traceability.


60-61: LGTM: Appropriate early return.

The empty prompts check prevents unnecessary processing and returns the expected empty list.


63-66: LGTM: Proper concurrency control setup.

Creating a semaphore per call correctly limits the concurrency for this specific batch_generate invocation.


116-119: LGTM: Helpful logging for processing visibility.

The info-level log provides useful context for monitoring batch operations.


140-146: LGTM: Correct order restoration and response extraction.

The sort-by-index approach correctly restores the original prompt order, assuming indices are assigned correctly upstream.

tests/test_batch_generate.py (13)

18-28: LGTM: Solid basic functionality test.

The test verifies correct response count and content matching for basic batch generation.


30-43: LGTM: Good edge case coverage.

These tests appropriately verify behavior with empty and single-element prompt lists.


45-55: LGTM: Validates order preservation under concurrency.

This test correctly verifies that concurrent processing maintains the original prompt order in the results.


57-70: LGTM: Tests different batch size configurations.

The test validates that varying batch_size doesn't break functionality. However, it doesn't verify that batch sizing actually controls processing behavior (e.g., memory usage or sequential batch execution), which aligns with the implementation issue noted in benchwise/models.py.


72-111: LGTM: Effectively validates concurrency limits.

The concurrent execution tracking correctly verifies that max_concurrent is respected.


113-136: LGTM: Validates rate limit retry mechanism.

The test correctly verifies that rate-limited requests are retried and eventually succeed.


138-153: LGTM: Validates max retry exhaustion.

The test correctly verifies that persistent rate limit errors eventually result in error responses after exhausting retries.


155-185: LGTM: Validates partial failure handling.

The test correctly verifies that failures in some prompts don't prevent successful processing of others, and that failures appear in the correct positions.


187-213: LGTM: Validates exponential backoff behavior.

The test correctly verifies that retry delays follow exponential backoff. The timing assertion at line 213 accounts for the expected minimum delay but could be sensitive to system load or slow CI environments.

If this test becomes flaky in CI, consider either:

  1. Adding a larger margin (e.g., >= 0.25 instead of >= 0.3)
  2. Mocking asyncio.sleep to verify call arguments instead of measuring wall-clock time
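A sketch of option 2, using a hypothetical retry coroutine (retry_with_backoff is illustrative, not the PR's code): patching asyncio.sleep with an AsyncMock lets the test assert on the delays that were requested instead of measuring wall-clock time, which removes the CI-timing sensitivity.

```python
import asyncio
from unittest.mock import AsyncMock, patch

async def retry_with_backoff(max_retries=3, base_delay=1.0):
    """Hypothetical retry loop: every attempt is rate limited, so it sleeps
    with exponentially growing delays and then gives up."""
    for attempt in range(max_retries):
        await asyncio.sleep(base_delay * (2 ** attempt))
    return "gave up"

async def main():
    # Replace asyncio.sleep so no real waiting happens
    with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep:
        result = await retry_with_backoff()
        # Assert on the requested delays rather than elapsed time
        delays = [call.args[0] for call in mock_sleep.await_args_list]
    return result, delays

result, delays = asyncio.run(main())
```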

215-237: LGTM: Validates kwargs propagation.

The test correctly verifies that additional keyword arguments are passed through to the underlying generate method.


261-350: LGTM: Comprehensive error handling and adapter compatibility tests.

The test suite thoroughly validates:

  • Cross-adapter compatibility
  • Factory function integration
  • Correct retry behavior for different error types (rate limit vs non-rate limit)
  • Detection of various rate limit error formats (429, quota exceeded)

353-404: LGTM: Integration tests validate real adapter behavior.

The integration tests appropriately mock external API clients and verify that batch_generate correctly invokes the underlying adapter's generate method the expected number of times.


407-413: LGTM: Standard test runner.

The __main__ block allows convenient direct execution of the test file.

Comment on lines +129 to +138
    # Gather all results with exception handling
    results_with_indices = await asyncio.gather(*all_tasks, return_exceptions=True)

    # Handle any exceptions that weren't caught
    processed_results = []
    for item in results_with_indices:
        if isinstance(item, Exception):
            processed_results.append((len(processed_results), f"Error: {str(item)}"))
        else:
            processed_results.append(item)

@coderabbitai (bot) commented Dec 13, 2025


⚠️ Potential issue | 🟠 Major

Incorrect index assignment for uncaught exceptions.

Line 136 uses len(processed_results) as the index for uncaught exceptions, which assumes exceptions are processed in order and no previous exceptions occurred. However, if an exception occurs at index 2 and this is the first item processed, it would be assigned index 0 instead of 2, breaking order preservation.

Since process_single_prompt already returns (index, result) tuples and catches exceptions internally, exceptions reaching this point indicate a serious bug (e.g., in the exception handling itself). Consider logging these as critical errors instead of trying to recover:

 # Handle any exceptions that weren't caught
 processed_results = []
 for item in results_with_indices:
     if isinstance(item, Exception):
-        processed_results.append((len(processed_results), f"Error: {str(item)}"))
+        logger.critical(f"Unexpected exception escaped process_single_prompt: {item}")
+        # This should never happen, indicates a bug in process_single_prompt
+        raise item
     else:
         processed_results.append(item)

Alternatively, if you want to be defensive, you need to track which index failed. However, since asyncio.gather doesn't preserve position information when returning exceptions, you'd need to use asyncio.create_task with tracking or asyncio.as_completed.

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.14.8)

  • 136: Use explicit conversion flag; replace with conversion flag (RUF010)

🤖 Prompt for AI Agents
In benchwise/models.py around lines 129 to 138, the code assigns indices for
exceptions using len(processed_results) which can mislabel positions; instead
iterate with enumerate over results_with_indices so you use the original
position (e.g., for idx, item in enumerate(results_with_indices):) and when item
is an Exception append (idx, f"Error: {str(item)}"); keep the else branch as-is
to append the returned (index, result) tuples; additionally log or raise a
critical error when an exception appears here since process_single_prompt should
have handled it.

Collaborator Author


@CodeRabbit If an exception occurred at index 2, there would already be two items in processed_results and its length would be 2, so the exception would be added at index 2, not index 0.

I'm quoting you on this

Line 136 uses len(processed_results) as the index for uncaught exceptions, which assumes exceptions are processed in order and no previous exceptions occurred. However, if an exception occurs at index 2 and this is the first item processed, it would be assigned index 0 instead of 2, breaking order preservation.


@yanurag-dev You're absolutely right, and I apologize for the confusion in my analysis.

Since asyncio.gather() preserves the order of results and we iterate through results_with_indices sequentially, len(processed_results) at each iteration does indeed equal the current index. So if an exception occurs at position 2, by the time we reach it in the iteration, processed_results would already contain 2 items (from indices 0 and 1), making len(processed_results) == 2, which is the correct index.

The code is actually working as intended. This exception handling block serves as defensive code for truly unexpected errors (like a bug in process_single_prompt itself or the semaphore failing), since process_single_prompt is designed to catch all exceptions internally and always return a (index, result_or_error) tuple.

Thank you for the correction!
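The exchange above is easy to verify directly: asyncio.gather(..., return_exceptions=True) returns results in submission order, so the sequential len(processed_results) counter lands on each item's original index even when the exception sits mid-list. A minimal reproduction (helper names are illustrative):

```python
import asyncio

async def ok(index):
    return (index, f"result {index}")

async def boom():
    raise ValueError("uncaught")

async def main():
    # The coroutine at position 2 raises; gather keeps it at position 2
    results = await asyncio.gather(ok(0), ok(1), boom(), ok(3),
                                   return_exceptions=True)
    processed = []
    for item in results:
        if isinstance(item, Exception):
            # len(processed) == 2 here, the exception's true position
            processed.append((len(processed), f"Error: {item}"))
        else:
            processed.append(item)
    return processed

processed = asyncio.run(main())
```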


@coderabbitai (bot) left a comment


Actionable comments posted: 0

🧹 Nitpick comments (3)
benchwise/models.py (3)

109-116: Unreachable code and logging improvement.

Two minor issues:

  1. Lines 115-116 are unreachable - the for loop always returns at either line 83 (success) or line 113 (error).
  2. Per static analysis (TRY400), logging.exception at line 112 would include the traceback for better debugging.
                         # Non-rate-limit error or max retries exceeded
-                        error_message = f"Error: {str(e)}"
+                        error_message = f"Error: {e!s}"
                         if attempt == max_retries:
-                            logger.error(f"Max retries exceeded for prompt {index}: {e}")
+                            logger.exception(f"Max retries exceeded for prompt {index}")
                         return (index, error_message)
-            
-            # Should never reach here, but just in case
-            return (index, "Error: Unknown error occurred")

118-138: Clarify the role of batch_size vs max_concurrent.

The current implementation queues all prompts as tasks upfront (lines 130-135) and gathers them together (line 138). The semaphore controls concurrency, not the batch_size. The batching loop (lines 119-122) only affects index calculation but doesn't create pauses between batches.

This means:

  • max_concurrent - limits how many prompts run simultaneously ✓
  • batch_size - currently has no effect on execution flow

If batch_size is intended for future use (e.g., inter-batch delays, per-batch logging), that's fine. Otherwise, consider removing it to avoid user confusion, or implement actual batched execution with optional delays between batches.

If batch_size is intentionally a no-op for now, consider adding a docstring note:

-            batch_size: Number of prompts to process per batch. If None, uses config value or default: 50
+            batch_size: Number of prompts per batch for logging/organization purposes. 
+                Actual concurrency is controlled by max_concurrent. If None, uses config value or default: 50

Alternatively, simplify by removing unused batching:

-        # Split prompts into batches
-        batches = []
-        for i in range(0, len(prompts), batch_size):
-            batch = prompts[i:i + batch_size]
-            batches.append(batch)
-        
         logger.info(
-            f"Processing {len(prompts)} prompts in {len(batches)} batch(es) "
-            f"with max {max_concurrent} concurrent requests"
+            f"Processing {len(prompts)} prompts with max {max_concurrent} concurrent requests"
         )
         
         # Process all prompts concurrently (up to max_concurrent at a time)
         all_tasks = []
-        for batch_idx, batch in enumerate(batches):
-            batch_start_idx = batch_idx * batch_size
-            for i, prompt in enumerate(batch):
-                task = process_single_prompt(batch_start_idx + i, prompt)
-                all_tasks.append(task)
+        for i, prompt in enumerate(prompts):
+            all_tasks.append(process_single_prompt(i, prompt))

88-94: Rate limit detection relies on string matching.

The rate limit detection checks for keywords in the error string. This is pragmatic but may miss provider-specific error formats or break if error messages change.

Consider also checking for exception types from specific provider SDKs when available (e.g., openai.RateLimitError, anthropic.RateLimitError), or document that subclasses can override this detection logic.
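A sketch of that suggestion: check exception types first, then fall back to the keyword matching taken from the PR. The rate_limit_types tuple is a placeholder for provider SDK classes such as openai.RateLimitError when those packages are installed (an assumption; this snippet stays provider-agnostic so it runs standalone).

```python
# Keywords mirror the PR's string-matching detection
RATE_LIMIT_KEYWORDS = ("rate limit", "429", "too many requests", "quota")

def is_rate_limit_error(exc, rate_limit_types=()):
    """Classify an exception as a rate-limit error.

    rate_limit_types would hold provider SDK exception classes
    (hypothetical here); the keyword check is the fallback.
    """
    if rate_limit_types and isinstance(exc, rate_limit_types):
        return True
    message = str(exc).lower()
    return any(keyword in message for keyword in RATE_LIMIT_KEYWORDS)
```

Subclasses of ModelAdapter could then override this hook with their provider's exception types while keeping the keyword fallback for unrecognized errors.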

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 3b7439a and 4c89806.

📒 Files selected for processing (1)
  • benchwise/models.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.14.8)
benchwise/models.py

  • 85: Do not catch blind exception: Exception (BLE001)
  • 99: Standard pseudo-random generators are not suitable for cryptographic purposes (S311)
  • 110: Use explicit conversion flag; replace with conversion flag (RUF010)
  • 112: Use logging.exception instead of logging.error; replace with exception (TRY400)
  • 144: Use explicit conversion flag; replace with conversion flag (RUF010)

🔇 Additional comments (1)
benchwise/models.py (1)

3-5: LGTM!

The imports are appropriate for the async batch processing, concurrency control, and logging functionality.

