diff --git a/benchwise/models.py b/benchwise/models.py
index cd5c88a..2bbeccd 100644
--- a/benchwise/models.py
+++ b/benchwise/models.py
@@ -1,5 +1,8 @@
 from abc import ABC, abstractmethod
 from typing import List, Dict, Any, Optional
+import asyncio
+import random
+import logging
 
 
 class ModelAdapter(ABC):
@@ -24,7 +27,133 @@ def get_cost_estimate(self, input_tokens: int, output_tokens: int) -> float:
         """Estimate cost for given token counts."""
         pass
 
+    async def batch_generate(
+        self,
+        prompts: List[str],
+        batch_size: Optional[int] = None,
+        max_concurrent: Optional[int] = None,
+        max_retries: Optional[int] = None,
+        base_delay: Optional[float] = None,
+        max_delay: Optional[float] = None,
+        **kwargs
+    ) -> List[str]:
+        """
+        Generate responses for a list of prompts with concurrent processing and rate limit protection.
+
+        Args:
+            prompts: List of prompts to process
+            batch_size: Number of prompts to process per batch. If None, uses config value or default: 50
+            max_concurrent: Maximum number of concurrent requests. If None, uses config value or default: 5
+            max_retries: Maximum number of retry attempts for rate limits. If None, uses config value or default: 3
+            base_delay: Base delay in seconds for exponential backoff. If None, uses config value or default: 1.0
+            max_delay: Maximum delay in seconds between retries. If None, uses config value or default: 60.0
+            **kwargs: Additional arguments passed to generate()
+
+        Returns:
+            List of responses in the same order as input prompts
+
+        Example:
+            >>> adapter = OpenAIAdapter("gpt-3.5-turbo", config={"max_concurrent": 10})
+            >>> prompts = ["Hello", "How are you?", "Tell me a joke"]
+            >>> responses = await adapter.batch_generate(prompts)  # Uses max_concurrent=10 from config
+            >>> responses = await adapter.batch_generate(prompts, max_concurrent=5)  # Overrides config
+        """
+        if not prompts:
+            return []
+
+        logger = logging.getLogger("benchwise.models")
+
+        # Get values from config if not explicitly provided, with hardcoded defaults as fallback
+        batch_size = batch_size if batch_size is not None else self.config.get("batch_size", 50)
+        max_concurrent = max_concurrent if max_concurrent is not None else self.config.get("max_concurrent", 5)
+        max_retries = max_retries if max_retries is not None else self.config.get("max_retries", 3)
+        base_delay = base_delay if base_delay is not None else self.config.get("base_delay", 1.0)
+        max_delay = max_delay if max_delay is not None else self.config.get("max_delay", 60.0)
+
+        # Create semaphore for this call to limit concurrency
+        semaphore = asyncio.Semaphore(max_concurrent)
+
+        async def process_single_prompt(index: int, prompt: str) -> tuple[int, str]:
+            """Process a single prompt with retry logic and rate limit handling."""
+            async with semaphore:
+                for attempt in range(max_retries + 1):
+                    try:
+                        # Call the generate method with a single prompt
+                        results = await self.generate([prompt], **kwargs)
+                        return (index, results[0])
+
+                    except Exception as e:
+                        error_str = str(e).lower()
+
+                        # Check if this is a rate limit error
+                        is_rate_limit = any([
+                            "rate limit" in error_str,
+                            "429" in error_str,
+                            "too many requests" in error_str,
+                            "quota" in error_str,
+                        ])
+
+                        if is_rate_limit and attempt < max_retries:
+                            # Calculate exponential backoff with jitter
+                            delay = min(base_delay * (2 ** attempt), max_delay)
+                            jitter = random.uniform(0, 0.1 * delay)
+                            total_delay = delay + jitter
+
+                            logger.warning(
+                                f"Rate limit hit for prompt {index}, retrying after {total_delay:.2f}s "
+                                f"(attempt {attempt + 1}/{max_retries})"
+                            )
+                            await asyncio.sleep(total_delay)
+                            continue
+
+                        # Non-rate-limit error or max retries exceeded
+                        error_message = f"Error: {str(e)}"
+                        if attempt == max_retries:
+                            logger.error(f"Max retries exceeded for prompt {index}: {e}")
+                        return (index, error_message)
+
+            # Should never reach here, but just in case
+            return (index, "Error: Unknown error occurred")
+
+        # Split prompts into batches
+        batches = []
+        for i in range(0, len(prompts), batch_size):
+            batch = prompts[i:i + batch_size]
+            batches.append(batch)
+
+        logger.info(
+            f"Processing {len(prompts)} prompts in {len(batches)} batch(es) "
+            f"with max {max_concurrent} concurrent requests"
+        )
+
+        # Process all prompts concurrently (up to max_concurrent at a time)
+        all_tasks = []
+        for batch_idx, batch in enumerate(batches):
+            batch_start_idx = batch_idx * batch_size
+            for i, prompt in enumerate(batch):
+                task = process_single_prompt(batch_start_idx + i, prompt)
+                all_tasks.append(task)
+
+        # Gather all results with exception handling
+        results_with_indices = await asyncio.gather(*all_tasks, return_exceptions=True)
+
+        # Handle any exceptions that weren't caught
+        processed_results = []
+        for item in results_with_indices:
+            if isinstance(item, Exception):
+                processed_results.append((len(processed_results), f"Error: {str(item)}"))
+            else:
+                processed_results.append(item)
+
+        # Sort by index to maintain original order and extract just the responses
+        processed_results.sort(key=lambda x: x[0])
+        responses = [result[1] for result in processed_results]
+
+        logger.info(f"Completed processing {len(responses)} prompts")
+
+        return responses
+
 
 class OpenAIAdapter(ModelAdapter):
     """Adapter for OpenAI models."""
diff --git a/tests/test_batch_generate.py b/tests/test_batch_generate.py
new file mode 100644
index 0000000..afd3aba
--- /dev/null
+++ b/tests/test_batch_generate.py
@@ -0,0 +1,413 @@
+import pytest
+import asyncio
+import time
+from unittest.mock import patch, AsyncMock, MagicMock
+from benchwise.models import (
+    ModelAdapter,
+    OpenAIAdapter,
+    AnthropicAdapter,
+    MockAdapter,
+    get_model_adapter,
+)
+
+
+@pytest.mark.asyncio
+class TestBatchGenerate:
+    """Test batch_generate functionality with concurrency and rate limiting."""
+
+    async def test_batch_generate_basic(self):
+        """Test basic batch generation with mock adapter."""
+        adapter = MockAdapter("mock-test")
+        prompts = ["Hello", "How are you?", "Tell me a joke"]
+
+        responses = await adapter.batch_generate(prompts, max_concurrent=2)
+
+        assert len(responses) == len(prompts)
+        for i, response in enumerate(responses):
+            assert "Mock response" in response
+            assert prompts[i][:50] in response
+
+    async def test_batch_generate_empty_prompts(self):
+        """Test batch generation with empty prompt list."""
+        adapter = MockAdapter("mock-test")
+        responses = await adapter.batch_generate([])
+
+        assert responses == []
+
+    async def test_batch_generate_single_prompt(self):
+        """Test batch generation with a single prompt."""
+        adapter = MockAdapter("mock-test")
+        responses = await adapter.batch_generate(["Single prompt"])
+
+        assert len(responses) == 1
+        assert "Mock response" in responses[0]
+
+    async def test_batch_generate_maintains_order(self):
+        """Test that batch generation maintains prompt order."""
+        adapter = MockAdapter("mock-test")
+        prompts = [f"Prompt {i}" for i in range(20)]
+
+        responses = await adapter.batch_generate(prompts, max_concurrent=5)
+
+        assert len(responses) == len(prompts)
+        for i, response in enumerate(responses):
+            # Check that response corresponds to the correct prompt
+            assert f"Prompt {i}" in response
+
+    async def test_batch_generate_with_different_batch_sizes(self):
+        """Test batch generation with various batch sizes."""
+        adapter = MockAdapter("mock-test")
+        prompts = [f"Prompt {i}" for i in range(25)]
+
+        # Test different batch sizes
+        for batch_size in [5, 10, 30]:
+            responses = await adapter.batch_generate(
+                prompts, batch_size=batch_size, max_concurrent=3
+            )
+
+            assert len(responses) == len(prompts)
+            for i, response in enumerate(responses):
+                assert f"Prompt {i}" in response
+
+    async def test_batch_generate_concurrency_limit(self):
+        """Test that max_concurrent properly limits concurrent requests."""
+        adapter = MockAdapter("mock-test")
+
+        # Track concurrent executions
+        concurrent_count = 0
+        max_concurrent_seen = 0
+        lock = asyncio.Lock()
+
+        # Mock the generate method to track concurrency
+        original_generate = adapter.generate
+
+        async def tracked_generate(prompts, **kwargs):
+            nonlocal concurrent_count, max_concurrent_seen
+
+            async with lock:
+                concurrent_count += 1
+                max_concurrent_seen = max(max_concurrent_seen, concurrent_count)
+
+            # Simulate some work
+            await asyncio.sleep(0.01)
+            result = await original_generate(prompts, **kwargs)
+
+            async with lock:
+                concurrent_count -= 1
+
+            return result
+
+        adapter.generate = tracked_generate
+
+        prompts = [f"Prompt {i}" for i in range(20)]
+        max_concurrent = 5
+
+        responses = await adapter.batch_generate(
+            prompts, max_concurrent=max_concurrent, batch_size=20
+        )
+
+        assert len(responses) == len(prompts)
+        # The max concurrent should not exceed our limit
+        assert max_concurrent_seen <= max_concurrent
+
+    async def test_batch_generate_rate_limit_retry(self):
+        """Test rate limit handling with retry logic."""
+        adapter = MockAdapter("mock-test")
+
+        call_count = 0
+
+        async def failing_then_succeeding_generate(prompts, **kwargs):
+            nonlocal call_count
+            call_count += 1
+
+            # First call fails with rate limit, second succeeds
+            if call_count == 1:
+                raise Exception("Rate limit exceeded")
+            return [f"Success on attempt {call_count}" for _ in prompts]
+
+        adapter.generate = failing_then_succeeding_generate
+
+        responses = await adapter.batch_generate(
+            ["Test prompt"], max_retries=3, base_delay=0.01
+        )
+
+        assert len(responses) == 1
+        assert "Success" in responses[0]
+        assert call_count == 2  # Failed once, succeeded on second attempt
+
+    async def test_batch_generate_max_retries_exceeded(self):
+        """Test that errors are returned when max retries exceeded."""
+        adapter = MockAdapter("mock-test")
+
+        async def always_failing_generate(prompts, **kwargs):
+            raise Exception("Rate limit exceeded - always fails")
+
+        adapter.generate = always_failing_generate
+
+        responses = await adapter.batch_generate(
+            ["Test prompt"], max_retries=2, base_delay=0.01
+        )
+
+        assert len(responses) == 1
+        assert "Error:" in responses[0]
+        assert "Rate limit exceeded" in responses[0]
+
+    async def test_batch_generate_partial_failures(self):
+        """Test handling of partial failures in batch."""
+        adapter = MockAdapter("mock-test")
+
+        call_index = 0
+
+        async def partially_failing_generate(prompts, **kwargs):
+            nonlocal call_index
+            result_index = call_index
+            call_index += 1
+
+            # Fail every third prompt
+            if result_index % 3 == 0:
+                raise Exception(f"Error for prompt {result_index}")
+            return [f"Success {result_index}"]
+
+        adapter.generate = partially_failing_generate
+
+        prompts = [f"Prompt {i}" for i in range(10)]
+        responses = await adapter.batch_generate(
+            prompts, max_concurrent=2, max_retries=0
+        )
+
+        assert len(responses) == len(prompts)
+
+        # Check that failures are in the correct positions
+        for i, response in enumerate(responses):
+            if i % 3 == 0:
+                assert "Error:" in response
+            else:
+                assert "Success" in response
+
+    async def test_batch_generate_exponential_backoff(self):
+        """Test exponential backoff timing for rate limits."""
+        adapter = MockAdapter("mock-test")
+
+        attempt_times = []
+
+        async def rate_limited_generate(prompts, **kwargs):
+            attempt_times.append(time.time())
+            if len(attempt_times) < 3:
+                raise Exception("Rate limit exceeded")
+            return ["Success"]
+
+        adapter.generate = rate_limited_generate
+
+        start_time = time.time()
+        responses = await adapter.batch_generate(
+            ["Test prompt"], max_retries=3, base_delay=0.1, max_delay=1.0
+        )
+        end_time = time.time()
+
+        assert len(responses) == 1
+        assert "Success" in responses[0]
+
+        # Check that exponential backoff was applied
+        # With base_delay=0.1, we expect delays of ~0.1, ~0.2
+        total_time = end_time - start_time
+        assert total_time >= 0.3  # At least 0.1 + 0.2 = 0.3 seconds
+
+    async def test_batch_generate_with_kwargs(self):
+        """Test that kwargs are passed through to generate method."""
+        adapter = MockAdapter("mock-test")
+
+        received_kwargs = {}
+
+        async def tracking_generate(prompts, **kwargs):
+            received_kwargs.update(kwargs)
+            return [f"Response for {p}" for p in prompts]
+
+        adapter.generate = tracking_generate
+
+        responses = await adapter.batch_generate(
+            ["Test"],
+            temperature=0.8,
+            max_tokens=500,
+            top_p=0.9
+        )
+
+        assert len(responses) == 1
+        assert received_kwargs["temperature"] == 0.8
+        assert received_kwargs["max_tokens"] == 500
+        assert received_kwargs["top_p"] == 0.9
+
+    async def test_batch_generate_large_dataset(self):
+        """Test batch generation with a large number of prompts."""
+        adapter = MockAdapter("mock-test")
+
+        # Generate 100 prompts
+        prompts = [f"Prompt {i}" for i in range(100)]
+
+        start_time = time.time()
+        responses = await adapter.batch_generate(
+            prompts, batch_size=20, max_concurrent=10
+        )
+        end_time = time.time()
+
+        assert len(responses) == len(prompts)
+
+        # Verify order is maintained
+        for i, response in enumerate(responses):
+            assert f"Prompt {i}" in response
+
+        # Should complete reasonably quickly due to concurrency
+        print(f"Processed {len(prompts)} prompts in {end_time - start_time:.2f}s")
+
+    async def test_batch_generate_different_adapters(self):
+        """Test batch generation works with different adapter types."""
+        adapters = [
+            MockAdapter("mock-test"),
+        ]
+
+        for adapter in adapters:
+            prompts = ["Test 1", "Test 2", "Test 3"]
+            responses = await adapter.batch_generate(prompts, max_concurrent=2)
+
+            assert len(responses) == len(prompts)
+            for response in responses:
+                assert len(response) > 0
+
+    async def test_batch_generate_with_factory(self):
+        """Test batch generation with models from factory function."""
+        adapter = get_model_adapter("mock-test")
+
+        prompts = ["Hello", "World"]
+        responses = await adapter.batch_generate(prompts, max_concurrent=2)
+
+        assert len(responses) == len(prompts)
+
+    async def test_batch_generate_non_rate_limit_errors(self):
+        """Test that non-rate-limit errors are not retried."""
+        adapter = MockAdapter("mock-test")
+
+        call_count = 0
+
+        async def failing_generate(prompts, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            raise ValueError("Invalid input")  # Non-rate-limit error
+
+        adapter.generate = failing_generate
+
+        responses = await adapter.batch_generate(
+            ["Test"], max_retries=3, base_delay=0.01
+        )
+
+        assert len(responses) == 1
+        assert "Error:" in responses[0]
+        # Should only try once since it's not a rate limit error
+        assert call_count == 1
+
+    async def test_batch_generate_429_status_code(self):
+        """Test detection of 429 status code in error messages."""
+        adapter = MockAdapter("mock-test")
+
+        call_count = 0
+
+        async def status_429_generate(prompts, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                raise Exception("HTTP Error 429: Too Many Requests")
+            return ["Success"]
+
+        adapter.generate = status_429_generate
+
+        responses = await adapter.batch_generate(
+            ["Test"], max_retries=2, base_delay=0.01
+        )
+
+        assert len(responses) == 1
+        assert "Success" in responses[0]
+        assert call_count == 2  # Retried due to 429
+
+    async def test_batch_generate_quota_exceeded(self):
+        """Test detection of quota exceeded errors."""
+        adapter = MockAdapter("mock-test")
+
+        call_count = 0
+
+        async def quota_exceeded_generate(prompts, **kwargs):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                raise Exception("Quota exceeded for requests")
+            return ["Success"]
+
+        adapter.generate = quota_exceeded_generate
+
+        responses = await adapter.batch_generate(
+            ["Test"], max_retries=2, base_delay=0.01
+        )
+
+        assert len(responses) == 1
+        assert "Success" in responses[0]
+        assert call_count == 2  # Retried due to quota error
+
+
+@pytest.mark.asyncio
+class TestBatchGenerateIntegration:
+    """Integration tests for batch_generate with mock API responses."""
+
+    async def test_openai_adapter_batch_generate(self):
+        """Test batch_generate with OpenAI adapter (mocked)."""
+        with patch("openai.AsyncOpenAI") as mock_openai:
+            mock_client = mock_openai.return_value
+            mock_response = type(
+                "obj",
+                (object,),
+                {
+                    "choices": [
+                        type(
+                            "obj",
+                            (object,),
+                            {"message": type("obj", (object,), {"content": "Response"})()},
+                        )()
+                    ]
+                },
+            )()
+
+            mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
+
+            adapter = OpenAIAdapter("gpt-3.5-turbo")
+            prompts = ["Test 1", "Test 2", "Test 3"]
+
+            responses = await adapter.batch_generate(prompts, max_concurrent=2)
+
+            assert len(responses) == len(prompts)
+            # Should have been called 3 times (once per prompt)
+            assert mock_client.chat.completions.create.call_count == 3
+
+    async def test_anthropic_adapter_batch_generate(self):
+        """Test batch_generate with Anthropic adapter (mocked)."""
+        with patch("anthropic.AsyncAnthropic") as mock_anthropic:
+            mock_client = mock_anthropic.return_value
+            mock_response = type(
+                "obj",
+                (object,),
+                {"content": [type("obj", (object,), {"text": "Response"})()]},
+            )()
+
+            mock_client.messages.create = AsyncMock(return_value=mock_response)
+
+            adapter = AnthropicAdapter("claude-3-sonnet")
+            prompts = ["Test 1", "Test 2"]
+
+            responses = await adapter.batch_generate(prompts, max_concurrent=2)
+
+            assert len(responses) == len(prompts)
+            assert mock_client.messages.create.call_count == 2
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])
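
For reference, a minimal usage sketch of the new API (not part of the patch above). It assumes only what the tests exercise: MockAdapter is importable from benchwise.models and runs without credentials, explicit keyword arguments override adapter config, and results come back in input order with per-prompt failures as "Error: ..." strings. The script name and prompt contents are illustrative.

import asyncio

from benchwise.models import MockAdapter


async def main() -> None:
    # MockAdapter keeps the sketch runnable offline; a real run would use
    # OpenAIAdapter or AnthropicAdapter with valid credentials instead.
    adapter = MockAdapter("mock-test")
    prompts = [f"Summarize item {i}" for i in range(100)]

    # Explicit arguments override adapter.config; anything left unset falls
    # back to config, then to the hardcoded defaults (50 / 5 / 3 / 1.0 / 60.0).
    responses = await adapter.batch_generate(
        prompts, batch_size=25, max_concurrent=5, max_retries=2
    )

    # Responses preserve input order; failed prompts surface as error strings
    # rather than raising, so the lengths always match.
    assert len(responses) == len(prompts)


if __name__ == "__main__":
    asyncio.run(main())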