From ffc92d69bf702a5283e32ba74ec04a8e2443ac91 Mon Sep 17 00:00:00 2001 From: Josh Gieringer Date: Wed, 14 Jan 2026 17:05:27 -0700 Subject: [PATCH 1/4] Retry on error but break if broken --- llm_clients/claude_llm.py | 15 +- llm_clients/gemini_llm.py | 15 +- llm_clients/llama_llm.py | 17 ++- llm_clients/llm_interface.py | 163 +++++++++++++++++++++- llm_clients/openai_llm.py | 15 +- tests/unit/llm_clients/test_claude_llm.py | 73 ++++++---- tests/unit/llm_clients/test_gemini_llm.py | 71 ++++++---- tests/unit/llm_clients/test_llama_llm.py | 144 ++++++++++++------- tests/unit/llm_clients/test_openai_llm.py | 69 +++++---- 9 files changed, 427 insertions(+), 155 deletions(-) diff --git a/llm_clients/claude_llm.py b/llm_clients/claude_llm.py index 0e9ffc27..78961a72 100644 --- a/llm_clients/claude_llm.py +++ b/llm_clients/claude_llm.py @@ -23,9 +23,10 @@ def __init__( name: str, system_prompt: Optional[str] = None, model_name: Optional[str] = None, + max_retries: int = 10, **kwargs, ): - super().__init__(name, system_prompt) + super().__init__(name, system_prompt, max_retries=max_retries) if not Config.ANTHROPIC_API_KEY: raise ValueError("ANTHROPIC_API_KEY not found in environment variables") @@ -98,7 +99,15 @@ async def generate_response( try: start_time = time.time() - response = await self.llm.ainvoke(messages) + + # Use retry logic for API call + async def _invoke(): + return await self.llm.ainvoke(messages) + + response = await self._retry_with_backoff( + _invoke, + operation_name="generate_response", + ) end_time = time.time() # Extract metadata from response @@ -148,7 +157,7 @@ async def generate_response( "error": str(e), "usage": {}, } - return f"Error generating response: {str(e)}" + raise RuntimeError(f"Error generating response: {str(e)}") from e async def generate_structured_response( self, message: Optional[str], response_model: Type[T] diff --git a/llm_clients/gemini_llm.py b/llm_clients/gemini_llm.py index 953508bc..f1fb2d7c 100644 --- a/llm_clients/gemini_llm.py +++ b/llm_clients/gemini_llm.py @@ -23,9 +23,10 @@ def __init__( name: str, system_prompt: Optional[str] = None, model_name: Optional[str] = None, + max_retries: int = 10, **kwargs, ): - super().__init__(name, system_prompt) + super().__init__(name, system_prompt, max_retries=max_retries) if not Config.GOOGLE_API_KEY: raise ValueError("GOOGLE_API_KEY not found in environment variables") @@ -96,7 +97,15 @@ async def generate_response( try: start_time = time.time() - response = await self.llm.ainvoke(messages) + + # Use retry logic for API call + async def _invoke(): + return await self.llm.ainvoke(messages) + + response = await self._retry_with_backoff( + _invoke, + operation_name="generate_response", + ) end_time = time.time() # Extract metadata from response @@ -157,7 +166,7 @@ async def generate_response( "error": str(e), "usage": {}, } - return f"Error generating response: {str(e)}" + raise RuntimeError(f"Error generating response: {str(e)}") from e def get_last_response_metadata(self) -> Dict[str, Any]: """Get metadata from the last response.""" diff --git a/llm_clients/llama_llm.py b/llm_clients/llama_llm.py index 00b954ef..dc0c20bb 100644 --- a/llm_clients/llama_llm.py +++ b/llm_clients/llama_llm.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Dict, List, Optional from langchain_community.llms import Ollama @@ -21,9 +22,10 @@ def __init__( name: str, system_prompt: Optional[str] = None, model_name: Optional[str] = None, + max_retries: int = 10, **kwargs, ): - super().__init__(name, system_prompt) + 
super().__init__(name, system_prompt, max_retries=max_retries) # Use provided model name or fall back to config default self.model_name = model_name or Config.get_llama_config()["model"] @@ -59,11 +61,18 @@ async def generate_response( ) # Ollama doesn't have native async support in langchain-community - # So we'll use the synchronous version - response = self.llm.invoke(full_message) + # So we'll use the synchronous version, wrapped in async for retry logic + async def _invoke(): + # Run sync invoke in thread pool to avoid blocking + return await asyncio.to_thread(self.llm.invoke, full_message) + + response = await self._retry_with_backoff( + _invoke, + operation_name="generate_response", + ) return response except Exception as e: - return f"Error generating response: {str(e)}" + raise RuntimeError(f"Error generating response: {str(e)}") from e def set_system_prompt(self, system_prompt: str) -> None: """Set or update the system prompt.""" diff --git a/llm_clients/llm_interface.py b/llm_clients/llm_interface.py index 19361799..71786715 100644 --- a/llm_clients/llm_interface.py +++ b/llm_clients/llm_interface.py @@ -1,5 +1,6 @@ +import asyncio from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Type, TypeVar +from typing import Any, Callable, Dict, List, Optional, Type, TypeVar from pydantic import BaseModel @@ -13,9 +14,12 @@ class LLMInterface(ABC): must support basic text generation and system prompt management. """ - def __init__(self, name: str, system_prompt: Optional[str] = None): + def __init__( + self, name: str, system_prompt: Optional[str] = None, max_retries: int = 10 + ): self.name = name self.system_prompt = system_prompt or "" + self.max_retries = max_retries @abstractmethod async def generate_response( @@ -47,6 +51,161 @@ def get_name(self) -> str: """Get the name of this LLM instance.""" return self.name + def _extract_http_status_code(self, exception: Exception) -> Optional[int]: + """Extract HTTP status code from exception if available. + + LangChain and various HTTP libraries wrap HTTP errors differently. + This method attempts to extract the status code from common + exception types. 
+ """ + # Check for status_code attribute (common in HTTPException) + if hasattr(exception, "status_code"): + status_code = getattr(exception, "status_code") + if status_code is not None: + return int(status_code) + + # Check for response attribute with status_code + if hasattr(exception, "response"): + response = getattr(exception, "response") + if hasattr(response, "status_code"): + status_code = getattr(response, "status_code") + if status_code is not None: + return int(status_code) + if hasattr(response, "status"): + status = getattr(response, "status") + if status is not None: + return int(status) + + # Check for status attribute directly + if hasattr(exception, "status"): + status = getattr(exception, "status") + if status is not None: + return int(status) + + # Check exception message for status codes (fallback) + error_str = str(exception).lower() + for code in [429, 500, 502, 503, 504, 529]: + if f"status {code}" in error_str or f"status_code {code}" in error_str: + return code + + return None + + def _extract_retry_after(self, exception: Exception) -> Optional[int]: + """Extract Retry-After header value from exception if available.""" + if hasattr(exception, "response"): + response = getattr(exception, "response") + if hasattr(response, "headers"): + headers = getattr(response, "headers") + retry_after = headers.get("Retry-After") or headers.get("retry-after") + if retry_after: + try: + return int(retry_after) + except (ValueError, TypeError): + pass + return None + + async def _retry_with_backoff( + self, + func: Callable[[], Any], + operation_name: str = "operation", + ) -> Any: + """Execute a function with retry logic for transient HTTP errors. + + Handles the following HTTP status codes: + - 429 (Too Many Requests): Respects Retry-After header, + otherwise exponential backoff + - 500 (Internal Server Error): Retry 1-3 times with + exponential backoff + - 502 (Bad Gateway): Retry 1-3 times with exponential backoff + - 503 (Service Unavailable): Exponential backoff + - 504 (Gateway Timeout): Exponential backoff + - 529 (Overloaded - Anthropic): Treated like 503 with + exponential backoff + + Args: + func: Async function to execute + operation_name: Name of operation for error messages + + Returns: + Result of func() + + Raises: + RuntimeError: If max retries exceeded or non-retryable + error occurs + """ + retryable_status_codes = {429, 500, 502, 503, 504, 529} + max_retries_for_500_502 = 3 # Limit retries for 500/502 + + last_exception = None + + for attempt in range(self.max_retries): + try: + return await func() + except Exception as e: + last_exception = e + status_code = self._extract_http_status_code(e) + + # If we can't determine status code, check if it's + # retryable by message + if status_code is None: + error_str = str(e).lower() + # Check for common retryable error messages + retryable_keywords = [ + "rate limit", + "too many requests", + "service unavailable", + "internal server error", + "bad gateway", + "gateway timeout", + "overloaded", + "timeout", + ] + if any(keyword in error_str for keyword in retryable_keywords): + # Treat as retryable, use exponential backoff + status_code = 503 # Default for unknown retryable + else: + # Non-retryable error, raise immediately + raise RuntimeError( + f"Error in {operation_name}: {str(e)}" + ) from e + + # Check if this is a retryable status code + if status_code not in retryable_status_codes: + # Non-retryable error, raise immediately + raise RuntimeError(f"Error in {operation_name}: {str(e)}") from e + + # For 500 and 
502, limit retries to max_retries_for_500_502 + if status_code in {500, 502} and attempt >= max_retries_for_500_502 - 1: + raise RuntimeError( + f"Error in {operation_name} after " + f"{max_retries_for_500_502} retries: {str(e)}" + ) from e + + # Calculate wait time + if status_code == 429: + # Check for Retry-After header + retry_after = self._extract_retry_after(e) + if retry_after is not None: + wait_time = retry_after + else: + # Exponential backoff: 2^attempt seconds, max 60s + wait_time = min(2**attempt, 60) + elif status_code in {503, 529}: + # Exponential backoff for capacity issues + wait_time = min(2**attempt, 60) + else: # 500, 502, 504 + # Exponential backoff for transient errors + wait_time = min(2**attempt, 60) + + # Wait before retrying + await asyncio.sleep(wait_time) + + # Max retries exceeded + raise RuntimeError( + f"Error in {operation_name} after {self.max_retries} retries: " + f"{str(last_exception)}" + ) from last_exception + def __getattr__(self, name): """Delegate attribute access to the underlying llm object. diff --git a/llm_clients/openai_llm.py b/llm_clients/openai_llm.py index 3b62ba51..64928a13 100644 --- a/llm_clients/openai_llm.py +++ b/llm_clients/openai_llm.py @@ -23,9 +23,10 @@ def __init__( name: str, system_prompt: Optional[str] = None, model_name: Optional[str] = None, + max_retries: int = 10, **kwargs, ): - super().__init__(name, system_prompt) + super().__init__(name, system_prompt, max_retries=max_retries) if not Config.OPENAI_API_KEY: raise ValueError("OPENAI_API_KEY not found in environment variables") @@ -95,7 +96,15 @@ async def generate_response( try: start_time = time.time() - response = await self.llm.ainvoke(messages) + + # Use retry logic for API call + async def _invoke(): + return await self.llm.ainvoke(messages) + + response = await self._retry_with_backoff( + _invoke, + operation_name="generate_response", + ) end_time = time.time() # Extract metadata from response - capturing all available fields @@ -177,7 +186,7 @@ async def generate_response( "system_fingerprint": None, "logprobs": None, } - return f"Error generating response: {str(e)}" + raise RuntimeError(f"Error generating response: {str(e)}") from e async def generate_structured_response( self, message: Optional[str], response_model: Type[T] diff --git a/tests/unit/llm_clients/test_claude_llm.py b/tests/unit/llm_clients/test_claude_llm.py index c7700cf4..fe7b56fb 100644 --- a/tests/unit/llm_clients/test_claude_llm.py +++ b/tests/unit/llm_clients/test_claude_llm.py @@ -14,7 +14,7 @@ class TestClaudeLLM: def test_init_missing_api_key_raises_error(self): """Test that missing ANTHROPIC_API_KEY raises ValueError (line 25).""" with pytest.raises(ValueError) as exc_info: - ClaudeLLM(name="TestClaude") + ClaudeLLM(name="TestClaude", max_retries=1) assert "ANTHROPIC_API_KEY not found" in str(exc_info.value) @@ -26,7 +26,7 @@ def test_init_with_default_model(self, mock_chat_anthropic): mock_llm.model = "claude-3-5-sonnet-20241022" mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude", system_prompt="Test prompt") + llm = ClaudeLLM(name="TestClaude", max_retries=1, system_prompt="Test prompt") assert llm.name == "TestClaude" assert llm.system_prompt == "Test prompt" @@ -41,7 +41,9 @@ def test_init_with_custom_model(self, mock_chat_anthropic): mock_llm.model = "claude-3-opus-20240229" mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude", model_name="claude-3-opus-20240229") + llm = ClaudeLLM( + name="TestClaude", max_retries=1, 
model_name="claude-3-opus-20240229" + ) assert llm.model_name == "claude-3-opus-20240229" @@ -53,7 +55,9 @@ def test_init_with_kwargs(self, mock_chat_anthropic): mock_llm.model = "claude-3-5-sonnet-20241022" mock_chat_anthropic.return_value = mock_llm - ClaudeLLM(name="TestClaude", temperature=0.5, max_tokens=500, top_p=0.9) + ClaudeLLM( + name="TestClaude", max_retries=1, temperature=0.5, max_tokens=500, top_p=0.9 + ) # Verify kwargs were passed to ChatAnthropic call_kwargs = mock_chat_anthropic.call_args[1] @@ -84,7 +88,11 @@ async def test_generate_response_success_with_system_prompt( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude", system_prompt="You are a helpful assistant.") + llm = ClaudeLLM( + name="TestClaude", + max_retries=1, + system_prompt="You are a helpful assistant.", + ) response = await llm.generate_response( conversation_history=[ {"turn": 0, "speaker": "system", "response": "Hello, Claude!"} @@ -122,7 +130,7 @@ async def test_generate_response_without_system_prompt(self, mock_chat_anthropic mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") # No system prompt + llm = ClaudeLLM(name="TestClaude", max_retries=1) # No system prompt response = await llm.generate_response( conversation_history=[ {"turn": 0, "speaker": "system", "response": "Test message"} @@ -153,7 +161,7 @@ async def test_generate_response_without_usage_metadata(self, mock_chat_anthropi mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") + llm = ClaudeLLM(name="TestClaude", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -181,7 +189,7 @@ async def test_generate_response_without_response_metadata( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") + llm = ClaudeLLM(name="TestClaude", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -204,16 +212,21 @@ async def test_generate_response_api_error(self, mock_chat_anthropic): mock_llm.ainvoke = AsyncMock(side_effect=Exception("API rate limit exceeded")) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") - response = await llm.generate_response( - conversation_history=[ - {"turn": 0, "speaker": "system", "response": "Test message"} - ] - ) - - # Should return error message instead of raising - assert "Error generating response" in response - assert "API rate limit exceeded" in response + llm = ClaudeLLM(name="TestClaude", max_retries=1) + error = None + try: + _ = await llm.generate_response( + conversation_history=[ + {"turn": 0, "speaker": "system", "response": "Test message"} + ] + ) + except Exception as e: + error = str(e) + + # Should raise exception with error message + assert error is not None + assert "Error generating response" in error + assert "API rate limit exceeded" in error # Verify error metadata was stored (lines 100-107) metadata = llm.get_last_response_metadata() @@ -241,7 +254,7 @@ async def test_generate_response_tracks_timing(self, mock_chat_anthropic): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") + llm = 
ClaudeLLM(name="TestClaude", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -259,7 +272,7 @@ def test_get_last_response_metadata_returns_copy(self): mock_llm.model = "claude-3-5-sonnet-20241022" mock_chat.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") + llm = ClaudeLLM(name="TestClaude", max_retries=1) llm.last_response_metadata = {"test": "value"} metadata1 = llm.get_last_response_metadata() @@ -281,7 +294,9 @@ def test_set_system_prompt(self): mock_llm.model = "claude-3-5-sonnet-20241022" mock_chat.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude", system_prompt="Initial prompt") + llm = ClaudeLLM( + name="TestClaude", max_retries=1, system_prompt="Initial prompt" + ) assert llm.system_prompt == "Initial prompt" llm.set_system_prompt("Updated prompt") @@ -309,7 +324,7 @@ async def test_generate_response_with_partial_usage_metadata( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") + llm = ClaudeLLM(name="TestClaude", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -336,7 +351,7 @@ async def test_metadata_includes_response_object(self, mock_chat_anthropic): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") + llm = ClaudeLLM(name="TestClaude", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -361,7 +376,7 @@ async def test_timestamp_format(self, mock_chat_anthropic): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") + llm = ClaudeLLM(name="TestClaude", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -397,7 +412,7 @@ async def test_metadata_with_stop_reason(self, mock_chat_anthropic): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") + llm = ClaudeLLM(name="TestClaude", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -425,7 +440,7 @@ async def test_raw_metadata_stored(self, mock_chat_anthropic): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude") + llm = ClaudeLLM(name="TestClaude", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -454,7 +469,7 @@ async def test_generate_response_with_conversation_history( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude", system_prompt="Test") + llm = ClaudeLLM(name="TestClaude", max_retries=1, system_prompt="Test") # Provide conversation history including the current turn history = [ @@ -511,7 +526,7 @@ async def test_generate_response_with_empty_conversation_history( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude", system_prompt="Test") + llm = ClaudeLLM(name="TestClaude", max_retries=1, system_prompt="Test") response = await llm.generate_response( 
conversation_history=[{"turn": 0, "speaker": "system", "response": "Hello"}] @@ -542,7 +557,7 @@ async def test_generate_response_with_none_conversation_history( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_anthropic.return_value = mock_llm - llm = ClaudeLLM(name="TestClaude", system_prompt="Test") + llm = ClaudeLLM(name="TestClaude", max_retries=1, system_prompt="Test") response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Hello"}] diff --git a/tests/unit/llm_clients/test_gemini_llm.py b/tests/unit/llm_clients/test_gemini_llm.py index 7b8f8ddc..df1b404e 100644 --- a/tests/unit/llm_clients/test_gemini_llm.py +++ b/tests/unit/llm_clients/test_gemini_llm.py @@ -14,7 +14,7 @@ class TestGeminiLLM: def test_init_missing_api_key_raises_error(self): """Test that missing GOOGLE_API_KEY raises ValueError (line 25).""" with pytest.raises(ValueError) as exc_info: - GeminiLLM(name="TestGemini") + GeminiLLM(name="TestGemini", max_retries=1) assert "GOOGLE_API_KEY not found" in str(exc_info.value) @@ -25,7 +25,7 @@ def test_init_with_default_model(self, mock_chat_gemini): mock_llm = MagicMock() mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini", system_prompt="Test prompt") + llm = GeminiLLM(name="TestGemini", max_retries=1, system_prompt="Test prompt") assert llm.name == "TestGemini" assert llm.system_prompt == "Test prompt" @@ -39,7 +39,7 @@ def test_init_with_custom_model(self, mock_chat_gemini): mock_llm = MagicMock() mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini", model_name="gemini-1.5-flash") + llm = GeminiLLM(name="TestGemini", max_retries=1, model_name="gemini-1.5-flash") assert llm.model_name == "gemini-1.5-flash" @@ -50,7 +50,9 @@ def test_init_with_kwargs(self, mock_chat_gemini): mock_llm = MagicMock() mock_chat_gemini.return_value = mock_llm - GeminiLLM(name="TestGemini", temperature=0.5, max_tokens=500, top_p=0.9) + GeminiLLM( + name="TestGemini", max_retries=1, temperature=0.5, max_tokens=500, top_p=0.9 + ) # Verify kwargs were passed to ChatGoogleGenerativeAI call_kwargs = mock_chat_gemini.call_args[1] @@ -95,7 +97,11 @@ async def test_generate_response_success_with_system_prompt(self, mock_chat_gemi mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini", system_prompt="You are a helpful assistant.") + llm = GeminiLLM( + name="TestGemini", + max_retries=1, + system_prompt="You are a helpful assistant.", + ) response = await llm.generate_response( conversation_history=[ {"turn": 0, "speaker": "system", "response": "Hello, Gemini!"} @@ -132,7 +138,7 @@ async def test_generate_response_without_system_prompt(self, mock_chat_gemini): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") # No system prompt + llm = GeminiLLM(name="TestGemini", max_retries=1) # No system prompt response = await llm.generate_response( conversation_history=[ {"turn": 0, "speaker": "system", "response": "Test message"} @@ -168,7 +174,7 @@ async def test_generate_response_with_fallback_token_usage(self, mock_chat_gemin mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") + llm = GeminiLLM(name="TestGemini", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -195,7 
+201,7 @@ async def test_generate_response_without_usage_metadata(self, mock_chat_gemini): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") + llm = GeminiLLM(name="TestGemini", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -219,7 +225,7 @@ async def test_generate_response_without_response_metadata(self, mock_chat_gemin mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") + llm = GeminiLLM(name="TestGemini", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -241,16 +247,21 @@ async def test_generate_response_api_error(self, mock_chat_gemini): mock_llm.ainvoke = AsyncMock(side_effect=Exception("API quota exceeded")) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") - response = await llm.generate_response( - conversation_history=[ - {"turn": 0, "speaker": "system", "response": "Test message"} - ] - ) - - # Should return error message instead of raising - assert "Error generating response" in response - assert "API quota exceeded" in response + llm = GeminiLLM(name="TestGemini", max_retries=1) + error = None + try: + _ = await llm.generate_response( + conversation_history=[ + {"turn": 0, "speaker": "system", "response": "Test message"} + ] + ) + except Exception as e: + error = str(e) + + # Should raise exception with error message + assert error is not None + assert "Error generating response" in error + assert "API quota exceeded" in error # Verify error metadata was stored metadata = llm.get_last_response_metadata() @@ -277,7 +288,7 @@ async def test_generate_response_tracks_timing(self, mock_chat_gemini): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") + llm = GeminiLLM(name="TestGemini", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -294,7 +305,7 @@ def test_get_last_response_metadata_returns_copy(self): mock_llm = MagicMock() mock_chat.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") + llm = GeminiLLM(name="TestGemini", max_retries=1) llm.last_response_metadata = {"test": "value"} metadata1 = llm.get_last_response_metadata() @@ -315,7 +326,9 @@ def test_set_system_prompt(self): mock_llm = MagicMock() mock_chat.return_value = mock_llm - llm = GeminiLLM(name="TestGemini", system_prompt="Initial prompt") + llm = GeminiLLM( + name="TestGemini", max_retries=1, system_prompt="Initial prompt" + ) assert llm.system_prompt == "Initial prompt" llm.set_system_prompt("Updated prompt") @@ -336,7 +349,7 @@ async def test_metadata_includes_response_object(self, mock_chat_gemini): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") + llm = GeminiLLM(name="TestGemini", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -360,7 +373,7 @@ async def test_timestamp_format(self, mock_chat_gemini): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") + llm = GeminiLLM(name="TestGemini", max_retries=1) await llm.generate_response( 
conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -395,7 +408,7 @@ async def test_finish_reason_extraction(self, mock_chat_gemini): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") + llm = GeminiLLM(name="TestGemini", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -422,7 +435,7 @@ async def test_raw_metadata_stored(self, mock_chat_gemini): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini") + llm = GeminiLLM(name="TestGemini", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -453,7 +466,7 @@ async def test_generate_response_with_conversation_history(self, mock_chat_gemin mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini", system_prompt="Test") + llm = GeminiLLM(name="TestGemini", max_retries=1, system_prompt="Test") # Provide conversation history including the current turn history = [ @@ -510,7 +523,7 @@ async def test_generate_response_with_empty_conversation_history( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini", system_prompt="Test") + llm = GeminiLLM(name="TestGemini", max_retries=1, system_prompt="Test") response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Hi"}] @@ -539,7 +552,7 @@ async def test_generate_response_with_none_conversation_history( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_gemini.return_value = mock_llm - llm = GeminiLLM(name="TestGemini", system_prompt="Test") + llm = GeminiLLM(name="TestGemini", max_retries=1, system_prompt="Test") response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Hi"}] diff --git a/tests/unit/llm_clients/test_llama_llm.py b/tests/unit/llm_clients/test_llama_llm.py index ba984bd3..b5f4818b 100644 --- a/tests/unit/llm_clients/test_llama_llm.py +++ b/tests/unit/llm_clients/test_llama_llm.py @@ -14,7 +14,7 @@ def test_init_with_default_config(self, mock_ollama): """Test initialization uses default config when no overrides provided.""" from llm_clients.llama_llm import LlamaLLM - LlamaLLM(name="test-llama") + LlamaLLM(name="test-llama", max_retries=1) # Verify Ollama was initialized with default config mock_ollama.assert_called_once() @@ -30,7 +30,7 @@ def test_init_with_custom_model_name(self, mock_ollama): """Test initialization with custom model name.""" from llm_clients.llama_llm import LlamaLLM - llm = LlamaLLM(name="test-llama", model_name="llama3:70b") + llm = LlamaLLM(name="test-llama", max_retries=1, model_name="llama3:70b") call_kwargs = mock_ollama.call_args[1] assert call_kwargs["model"] == "llama3:70b" @@ -41,7 +41,7 @@ def test_init_with_custom_temperature(self, mock_ollama): """Test initialization with custom temperature via kwargs.""" from llm_clients.llama_llm import LlamaLLM - LlamaLLM(name="test-llama", temperature=0.9) + LlamaLLM(name="test-llama", max_retries=1, temperature=0.9) call_kwargs = mock_ollama.call_args[1] assert call_kwargs["temperature"] == 0.9 @@ -52,7 +52,7 @@ def test_init_with_custom_base_url(self, mock_ollama): from llm_clients.llama_llm import LlamaLLM custom_url = 
"http://remote-server:11434" - LlamaLLM(name="test-llama", base_url=custom_url) + LlamaLLM(name="test-llama", max_retries=1, base_url=custom_url) call_kwargs = mock_ollama.call_args[1] assert call_kwargs["base_url"] == custom_url @@ -62,7 +62,13 @@ def test_init_kwargs_override_defaults(self, mock_ollama): """Test that kwargs override default config values.""" from llm_clients.llama_llm import LlamaLLM - LlamaLLM(name="test-llama", temperature=0.1, top_p=0.95, num_predict=500) + LlamaLLM( + name="test-llama", + max_retries=1, + temperature=0.1, + top_p=0.95, + num_predict=500, + ) call_kwargs = mock_ollama.call_args[1] assert call_kwargs["temperature"] == 0.1 @@ -84,7 +90,7 @@ async def test_generate_response_without_system_prompt(self, mock_ollama): mock_instance.invoke.return_value = "This is a test response" mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") + llm = LlamaLLM(name="test-llama", max_retries=1) response = await llm.generate_response( conversation_history=[ {"turn": 0, "speaker": "system", "response": "Hello, how are you?"} @@ -107,7 +113,11 @@ async def test_generate_response_with_system_prompt_in_init(self, mock_ollama): mock_instance.invoke.return_value = "I'm doing well, thanks!" mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama", system_prompt="You are a helpful assistant") + llm = LlamaLLM( + name="test-llama", + max_retries=1, + system_prompt="You are a helpful assistant", + ) response = await llm.generate_response( conversation_history=[ {"turn": 0, "speaker": "system", "response": "How are you?"} @@ -131,7 +141,7 @@ async def test_generate_response_with_system_prompt_set_later(self, mock_ollama) mock_instance.invoke.return_value = "Sure, I can help with that" mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") + llm = LlamaLLM(name="test-llama", max_retries=1) llm.set_system_prompt("You are a coding expert") response = await llm.generate_response( conversation_history=[ @@ -157,16 +167,21 @@ async def test_generate_response_handles_ollama_connection_error(self, mock_olla ) mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") - response = await llm.generate_response( - conversation_history=[ - {"turn": 0, "speaker": "system", "response": "Test message"} - ] - ) - - # Should return error message, not raise exception - assert "Error generating response" in response - assert "Could not connect to Ollama server" in response + llm = LlamaLLM(name="test-llama", max_retries=1) + error = None + try: + _ = await llm.generate_response( + conversation_history=[ + {"turn": 0, "speaker": "system", "response": "Test message"} + ] + ) + except Exception as e: + error = str(e) + + # Should raise exception with error message + assert error is not None + assert "Error generating response" in error + assert "Could not connect to Ollama server" in error @pytest.mark.asyncio @patch("llm_clients.llama_llm.Ollama") @@ -180,15 +195,22 @@ async def test_generate_response_handles_model_not_found(self, mock_ollama): ) mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama", model_name="nonexistent:latest") - response = await llm.generate_response( - conversation_history=[ - {"turn": 0, "speaker": "system", "response": "Test message"} - ] + llm = LlamaLLM( + name="test-llama", max_retries=1, model_name="nonexistent:latest" ) - - assert "Error generating response" in response - assert "Model 'nonexistent:latest' not found" in response + error = None + try: + _ = await 
llm.generate_response( + conversation_history=[ + {"turn": 0, "speaker": "system", "response": "Test message"} + ] + ) + except Exception as e: + error = str(e) + + assert error is not None + assert "Error generating response" in error + assert "Model 'nonexistent:latest' not found" in error @pytest.mark.asyncio @patch("llm_clients.llama_llm.Ollama") @@ -200,19 +222,24 @@ async def test_generate_response_handles_timeout_error(self, mock_ollama): mock_instance.invoke.side_effect = TimeoutError("Request timed out after 30s") mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") - response = await llm.generate_response( - conversation_history=[ - { - "turn": 0, - "speaker": "system", - "response": "Long message that times out", - } - ] - ) - - assert "Error generating response" in response - assert "Request timed out" in response + llm = LlamaLLM(name="test-llama", max_retries=1) + error = None + try: + _ = await llm.generate_response( + conversation_history=[ + { + "turn": 0, + "speaker": "system", + "response": "Long message that times out", + } + ] + ) + except Exception as e: + error = str(e) + + assert error is not None + assert "Error generating response" in error + assert "Request timed out" in error @pytest.mark.asyncio @patch("llm_clients.llama_llm.Ollama") @@ -224,13 +251,20 @@ async def test_generate_response_handles_generic_exception(self, mock_ollama): mock_instance.invoke.side_effect = RuntimeError("Unexpected error occurred") mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") - response = await llm.generate_response( - conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] - ) - - assert "Error generating response" in response - assert "Unexpected error occurred" in response + llm = LlamaLLM(name="test-llama", max_retries=1) + error = None + try: + _ = await llm.generate_response( + conversation_history=[ + {"turn": 0, "speaker": "system", "response": "Test"} + ] + ) + except Exception as e: + error = str(e) + + assert error is not None + assert "Error generating response" in error + assert "Unexpected error occurred" in error @pytest.mark.asyncio @patch("llm_clients.llama_llm.Ollama") @@ -242,7 +276,7 @@ async def test_generate_response_with_none_message(self, mock_ollama): mock_instance.invoke.return_value = "Default response" mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") + llm = LlamaLLM(name="test-llama", max_retries=1) response = await llm.generate_response(None) # Should handle None gracefully - message won't include current message part @@ -259,7 +293,7 @@ async def test_generate_response_with_empty_string(self, mock_ollama): mock_instance.invoke.return_value = "Response to empty" mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") + llm = LlamaLLM(name="test-llama", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": ""}] ) @@ -281,7 +315,7 @@ async def test_generate_response_preserves_multiline_messages(self, mock_ollama) mock_ollama.return_value = mock_instance multiline_msg = "Line 1\nLine 2\nLine 3" - llm = LlamaLLM(name="test-llama", system_prompt="Helper") + llm = LlamaLLM(name="test-llama", max_retries=1, system_prompt="Helper") await llm.generate_response( conversation_history=[ {"turn": 0, "speaker": "system", "response": multiline_msg} @@ -301,7 +335,7 @@ def test_set_system_prompt_updates_prompt(self, mock_ollama): """Test that set_system_prompt updates the 
system_prompt attribute.""" from llm_clients.llama_llm import LlamaLLM - llm = LlamaLLM(name="test-llama") + llm = LlamaLLM(name="test-llama", max_retries=1) # Initially empty string (from LLMInterface base class) assert llm.system_prompt == "" @@ -324,7 +358,7 @@ async def test_set_system_prompt_affects_subsequent_calls(self, mock_ollama): mock_instance.invoke.return_value = "Response" mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") + llm = LlamaLLM(name="test-llama", max_retries=1) # First call without system prompt await llm.generate_response( @@ -363,7 +397,9 @@ async def test_generate_response_with_conversation_history(self, mock_ollama): mock_instance.invoke.return_value = "Response with history" mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama", system_prompt="You are helpful") + llm = LlamaLLM( + name="test-llama", max_retries=1, system_prompt="You are helpful" + ) history = [ { @@ -408,7 +444,7 @@ async def test_generate_response_with_empty_conversation_history(self, mock_olla mock_instance.invoke.return_value = "Response" mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") + llm = LlamaLLM(name="test-llama", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Hello"}] @@ -430,7 +466,7 @@ async def test_generate_response_with_none_conversation_history(self, mock_ollam mock_instance.invoke.return_value = "Response" mock_ollama.return_value = mock_instance - llm = LlamaLLM(name="test-llama") + llm = LlamaLLM(name="test-llama", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] diff --git a/tests/unit/llm_clients/test_openai_llm.py b/tests/unit/llm_clients/test_openai_llm.py index f1b04772..4b38907c 100644 --- a/tests/unit/llm_clients/test_openai_llm.py +++ b/tests/unit/llm_clients/test_openai_llm.py @@ -14,7 +14,7 @@ class TestOpenAILLM: def test_init_missing_api_key_raises_error(self): """Test that missing OPENAI_API_KEY raises ValueError (line 25).""" with pytest.raises(ValueError) as exc_info: - OpenAILLM(name="TestOpenAI") + OpenAILLM(name="TestOpenAI", max_retries=1) assert "OPENAI_API_KEY not found" in str(exc_info.value) @@ -25,7 +25,7 @@ def test_init_with_default_model(self, mock_chat_openai): mock_llm = MagicMock() mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI", system_prompt="Test prompt") + llm = OpenAILLM(name="TestOpenAI", max_retries=1, system_prompt="Test prompt") assert llm.name == "TestOpenAI" assert llm.system_prompt == "Test prompt" @@ -39,7 +39,7 @@ def test_init_with_custom_model(self, mock_chat_openai): mock_llm = MagicMock() mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI", model_name="gpt-4-turbo") + llm = OpenAILLM(name="TestOpenAI", max_retries=1, model_name="gpt-4-turbo") assert llm.model_name == "gpt-4-turbo" @@ -50,7 +50,9 @@ def test_init_with_kwargs(self, mock_chat_openai): mock_llm = MagicMock() mock_chat_openai.return_value = mock_llm - OpenAILLM(name="TestOpenAI", temperature=0.5, max_tokens=500, top_p=0.9) + OpenAILLM( + name="TestOpenAI", max_retries=1, temperature=0.5, max_tokens=500, top_p=0.9 + ) # Verify kwargs were passed to ChatOpenAI call_kwargs = mock_chat_openai.call_args[1] @@ -90,7 +92,11 @@ async def test_generate_response_success_with_system_prompt(self, mock_chat_open mock_llm.ainvoke = AsyncMock(return_value=mock_response) 
mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI", system_prompt="You are a helpful assistant.") + llm = OpenAILLM( + name="TestOpenAI", + max_retries=1, + system_prompt="You are a helpful assistant.", + ) response = await llm.generate_response( conversation_history=[ {"turn": 0, "speaker": "system", "response": "Hello, GPT!"} @@ -131,7 +137,7 @@ async def test_generate_response_without_system_prompt(self, mock_chat_openai): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI") # No system prompt + llm = OpenAILLM(name="TestOpenAI", max_retries=1) # No system prompt response = await llm.generate_response( conversation_history=[ {"turn": 0, "speaker": "system", "response": "Test message"} @@ -161,7 +167,7 @@ async def test_generate_response_without_additional_kwargs(self, mock_chat_opena mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI") + llm = OpenAILLM(name="TestOpenAI", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -184,7 +190,7 @@ async def test_generate_response_without_response_metadata(self, mock_chat_opena mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI") + llm = OpenAILLM(name="TestOpenAI", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -218,7 +224,7 @@ async def test_generate_response_without_usage_metadata(self, mock_chat_openai): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI") + llm = OpenAILLM(name="TestOpenAI", max_retries=1) response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -241,16 +247,21 @@ async def test_generate_response_api_error(self, mock_chat_openai): mock_llm.ainvoke = AsyncMock(side_effect=Exception("API rate limit exceeded")) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI") - response = await llm.generate_response( - conversation_history=[ - {"turn": 0, "speaker": "system", "response": "Test message"} - ] - ) - - # Should return error message instead of raising - assert "Error generating response" in response - assert "API rate limit exceeded" in response + llm = OpenAILLM(name="TestOpenAI", max_retries=1) + error = None + try: + _ = await llm.generate_response( + conversation_history=[ + {"turn": 0, "speaker": "system", "response": "Test message"} + ] + ) + except Exception as e: + error = str(e) + + # Should raise exception with error message + assert error is not None + assert "Error generating response" in error + assert "API rate limit exceeded" in error # Verify error metadata was stored metadata = llm.get_last_response_metadata() @@ -277,7 +288,7 @@ async def test_generate_response_tracks_timing(self, mock_chat_openai): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI") + llm = OpenAILLM(name="TestOpenAI", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -294,7 +305,7 @@ def test_get_last_response_metadata_returns_copy(self): mock_llm = MagicMock() 
mock_chat.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI") + llm = OpenAILLM(name="TestOpenAI", max_retries=1) llm.last_response_metadata = {"test": "value"} metadata1 = llm.get_last_response_metadata() @@ -315,7 +326,9 @@ def test_set_system_prompt(self): mock_llm = MagicMock() mock_chat.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI", system_prompt="Initial prompt") + llm = OpenAILLM( + name="TestOpenAI", max_retries=1, system_prompt="Initial prompt" + ) assert llm.system_prompt == "Initial prompt" llm.set_system_prompt("Updated prompt") @@ -336,7 +349,7 @@ async def test_metadata_includes_response_object(self, mock_chat_openai): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI") + llm = OpenAILLM(name="TestOpenAI", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -360,7 +373,7 @@ async def test_timestamp_format(self, mock_chat_openai): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI") + llm = OpenAILLM(name="TestOpenAI", max_retries=1) await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -392,7 +405,7 @@ async def test_model_name_update_from_metadata(self, mock_chat_openai): mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI", model_name="gpt-4") + llm = OpenAILLM(name="TestOpenAI", max_retries=1, model_name="gpt-4") await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] ) @@ -421,7 +434,7 @@ async def test_generate_response_with_conversation_history(self, mock_chat_opena mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI", system_prompt="Test") + llm = OpenAILLM(name="TestOpenAI", max_retries=1, system_prompt="Test") # Provide conversation history including the current turn history = [ @@ -478,7 +491,7 @@ async def test_generate_response_with_empty_conversation_history( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI", system_prompt="Test") + llm = OpenAILLM(name="TestOpenAI", max_retries=1, system_prompt="Test") response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Hi"}] @@ -507,7 +520,7 @@ async def test_generate_response_with_none_conversation_history( mock_llm.ainvoke = AsyncMock(return_value=mock_response) mock_chat_openai.return_value = mock_llm - llm = OpenAILLM(name="TestOpenAI", system_prompt="Test") + llm = OpenAILLM(name="TestOpenAI", max_retries=1, system_prompt="Test") response = await llm.generate_response( conversation_history=[{"turn": 0, "speaker": "system", "response": "Hi"}] From 844ab09f49c295b69a94b553d7edc7b833538553 Mon Sep 17 00:00:00 2001 From: Josh Gieringer Date: Thu, 15 Jan 2026 11:36:31 -0700 Subject: [PATCH 2/4] support retry on empty response + extract error code --- llm_clients/claude_llm.py | 5 + llm_clients/gemini_llm.py | 5 + llm_clients/llama_llm.py | 13 + llm_clients/llm_interface.py | 36 ++- llm_clients/openai_llm.py | 5 + tests/unit/llm_clients/test_claude_llm.py | 106 +++++++++ tests/unit/llm_clients/test_llm_interface.py | 237 ++++++++++++++++++- 7 files changed, 399 
insertions(+), 8 deletions(-) diff --git a/llm_clients/claude_llm.py b/llm_clients/claude_llm.py index 78961a72..c9f25ca1 100644 --- a/llm_clients/claude_llm.py +++ b/llm_clients/claude_llm.py @@ -104,9 +104,14 @@ async def generate_response( async def _invoke(): return await self.llm.ainvoke(messages) + def _validate_response(response_obj): + """Validate that response has non-empty content.""" + return bool(response_obj.text and response_obj.text.strip()) + response = await self._retry_with_backoff( _invoke, operation_name="generate_response", + response_validator=_validate_response, ) end_time = time.time() diff --git a/llm_clients/gemini_llm.py b/llm_clients/gemini_llm.py index f1fb2d7c..437da82c 100644 --- a/llm_clients/gemini_llm.py +++ b/llm_clients/gemini_llm.py @@ -102,9 +102,14 @@ async def generate_response( async def _invoke(): return await self.llm.ainvoke(messages) + def _validate_response(response_obj): + """Validate that response has non-empty content.""" + return bool(response_obj.text and response_obj.text.strip()) + response = await self._retry_with_backoff( _invoke, operation_name="generate_response", + response_validator=_validate_response, ) end_time = time.time() diff --git a/llm_clients/llama_llm.py b/llm_clients/llama_llm.py index dc0c20bb..d0b0fbe2 100644 --- a/llm_clients/llama_llm.py +++ b/llm_clients/llama_llm.py @@ -66,9 +66,22 @@ async def _invoke(): # Run sync invoke in thread pool to avoid blocking return await asyncio.to_thread(self.llm.invoke, full_message) + def _validate_response(response_obj): + """Validate that response has non-empty content.""" + # Ollama may return string directly or a message object + if isinstance(response_obj, str): + return bool(response_obj and response_obj.strip()) + elif hasattr(response_obj, "text"): + return bool(response_obj.text and response_obj.text.strip()) + elif hasattr(response_obj, "content"): + return bool(response_obj.content and response_obj.content.strip()) + # If we can't determine, assume valid + return True + response = await self._retry_with_backoff( _invoke, operation_name="generate_response", + response_validator=_validate_response, ) return response except Exception as e: diff --git a/llm_clients/llm_interface.py b/llm_clients/llm_interface.py index 71786715..ee90c39b 100644 --- a/llm_clients/llm_interface.py +++ b/llm_clients/llm_interface.py @@ -1,4 +1,5 @@ import asyncio +import re from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Type, TypeVar @@ -88,6 +89,16 @@ def _extract_http_status_code(self, exception: Exception) -> Optional[int]: if f"status {code}" in error_str or f"status_code {code}" in error_str: return code + # Check for "Error Code" pattern in exception message + if "error code" in error_str: + # Try to extract numeric code after "error code" + match = re.search(r"error code[:\s]+(\d+)", error_str, re.IGNORECASE) + if match: + try: + return int(match.group(1)) + except (ValueError, TypeError): + pass + return None def _extract_retry_after(self, exception: Exception) -> Optional[int]: @@ -108,6 +119,7 @@ async def _retry_with_backoff( self, func: Callable[[], Any], operation_name: str = "operation", + response_validator: Optional[Callable[[Any], bool]] = None, ) -> Any: """Execute a function with retry logic for transient HTTP errors. 
@@ -122,16 +134,23 @@ async def _retry_with_backoff( - 529 (Overloaded - Anthropic): Treated like 503 with exponential backoff + Also retries if response_validator is provided and returns False + (e.g., for empty response content). + Args: func: Async function to execute operation_name: Name of operation for error messages + response_validator: Optional function to validate response. + If provided and returns False, will retry the operation. + Should accept the result of func() and return True if valid. Returns: Result of func() Raises: RuntimeError: If max retries exceeded or non-retryable - error occurs + error occurs, or if response validation fails after + max retries """ retryable_status_codes = {429, 500, 502, 503, 504, 529} max_retries_for_500_502 = 3 # Limit retries for 500/502 @@ -140,7 +159,18 @@ async def _retry_with_backoff( for attempt in range(self.max_retries): try: - return await func() + result = await func() + + # Validate response if validator is provided + if response_validator is not None: + if not response_validator(result): + # Response validation failed, treat as retryable error + raise ValueError( + f"Response validation failed in {operation_name}: " + "response content is empty or invalid" + ) + + return result except Exception as e: last_exception = e status_code = self._extract_http_status_code(e) @@ -159,6 +189,8 @@ async def _retry_with_backoff( "gateway timeout", "overloaded", "timeout", + "response validation failed", + "response content is empty", ] if any(keyword in error_str for keyword in retryable_keywords): # Treat as retryable, use exponential backoff diff --git a/llm_clients/openai_llm.py b/llm_clients/openai_llm.py index 64928a13..f836e7fe 100644 --- a/llm_clients/openai_llm.py +++ b/llm_clients/openai_llm.py @@ -101,9 +101,14 @@ async def generate_response( async def _invoke(): return await self.llm.ainvoke(messages) + def _validate_response(response_obj): + """Validate that response has non-empty content.""" + return bool(response_obj.text and response_obj.text.strip()) + response = await self._retry_with_backoff( _invoke, operation_name="generate_response", + response_validator=_validate_response, ) end_time = time.time() diff --git a/tests/unit/llm_clients/test_claude_llm.py b/tests/unit/llm_clients/test_claude_llm.py index fe7b56fb..5ad0464d 100644 --- a/tests/unit/llm_clients/test_claude_llm.py +++ b/tests/unit/llm_clients/test_claude_llm.py @@ -571,3 +571,109 @@ async def test_generate_response_with_none_conversation_history( # Should have: SystemMessage + current message only assert len(messages) == 2 + + @pytest.mark.asyncio + @patch("llm_clients.claude_llm.Config.ANTHROPIC_API_KEY", "test-key") + @patch("llm_clients.claude_llm.ChatAnthropic") + async def test_generate_response_retries_on_empty_response( + self, mock_chat_anthropic + ): + """Test that empty response content triggers retry.""" + mock_llm = MagicMock() + mock_llm.model = "claude-3-5-sonnet-20241022" + + call_count = 0 + + async def mock_ainvoke(messages): + nonlocal call_count + call_count += 1 + mock_response = MagicMock() + if call_count == 1: + # First call returns empty response + mock_response.text = "" + else: + # Subsequent calls return valid response + mock_response.text = "Valid response after retry" + mock_response.id = f"msg_{call_count}" + mock_response.response_metadata = {} + return mock_response + + mock_llm.ainvoke = AsyncMock(side_effect=mock_ainvoke) + mock_chat_anthropic.return_value = mock_llm + + llm = ClaudeLLM(name="TestClaude", max_retries=3) + 
response = await llm.generate_response( + conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] + ) + + assert response == "Valid response after retry" + assert call_count == 2 # Should have retried once + + @pytest.mark.asyncio + @patch("llm_clients.claude_llm.Config.ANTHROPIC_API_KEY", "test-key") + @patch("llm_clients.claude_llm.ChatAnthropic") + async def test_generate_response_retries_on_whitespace_only_response( + self, mock_chat_anthropic + ): + """Test that whitespace-only response triggers retry.""" + mock_llm = MagicMock() + mock_llm.model = "claude-3-5-sonnet-20241022" + + call_count = 0 + + async def mock_ainvoke(messages): + nonlocal call_count + call_count += 1 + mock_response = MagicMock() + if call_count == 1: + # First call returns whitespace-only response + mock_response.text = " \n\t " + else: + # Subsequent calls return valid response + mock_response.text = "Valid response" + mock_response.id = f"msg_{call_count}" + mock_response.response_metadata = {} + return mock_response + + mock_llm.ainvoke = AsyncMock(side_effect=mock_ainvoke) + mock_chat_anthropic.return_value = mock_llm + + llm = ClaudeLLM(name="TestClaude", max_retries=3) + response = await llm.generate_response( + conversation_history=[{"turn": 0, "speaker": "system", "response": "Test"}] + ) + + assert response == "Valid response" + assert call_count == 2 # Should have retried once + + @pytest.mark.asyncio + @patch("llm_clients.claude_llm.Config.ANTHROPIC_API_KEY", "test-key") + @patch("llm_clients.claude_llm.ChatAnthropic") + async def test_generate_response_empty_response_exhausts_retries( + self, mock_chat_anthropic + ): + """Test that empty response raises error after max retries.""" + mock_llm = MagicMock() + mock_llm.model = "claude-3-5-sonnet-20241022" + + async def mock_ainvoke(messages): + mock_response = MagicMock() + mock_response.text = "" # Always empty + mock_response.id = "msg_empty" + mock_response.response_metadata = {} + return mock_response + + mock_llm.ainvoke = AsyncMock(side_effect=mock_ainvoke) + mock_chat_anthropic.return_value = mock_llm + + llm = ClaudeLLM(name="TestClaude", max_retries=2) + + with pytest.raises(RuntimeError) as exc_info: + await llm.generate_response( + conversation_history=[ + {"turn": 0, "speaker": "system", "response": "Test"} + ] + ) + + assert "after 2 retries" in str(exc_info.value) + assert "response content is empty" in str(exc_info.value) diff --git a/tests/unit/llm_clients/test_llm_interface.py b/tests/unit/llm_clients/test_llm_interface.py index f71e2c77..8d5f6ca3 100644 --- a/tests/unit/llm_clients/test_llm_interface.py +++ b/tests/unit/llm_clients/test_llm_interface.py @@ -1,3 +1,4 @@ +from typing import Optional from unittest.mock import MagicMock import pytest @@ -8,8 +9,10 @@ class ConcreteLLM(LLMInterface): """Concrete implementation for testing abstract base class.""" - def __init__(self, name: str, system_prompt: str = None): - super().__init__(name, system_prompt) + def __init__( + self, name: str, system_prompt: Optional[str] = None, max_retries: int = 10 + ): + super().__init__(name, system_prompt, max_retries=max_retries) # Add a mock llm object for __getattr__ testing self.llm = MagicMock(spec=["temperature", "max_tokens", "custom_method"]) self.llm.temperature = 0.7 @@ -77,14 +80,14 @@ def test_set_system_prompt_abstract_method(self): def test_cannot_instantiate_abstract_class(self): """Test that LLMInterface cannot be instantiated directly.""" with pytest.raises(TypeError) as exc_info: - LLMInterface(name="Test") + 
LLMInterface(name="Test") # type: ignore[abstract] assert "Can't instantiate abstract class" in str(exc_info.value) def test_incomplete_implementation_raises_error(self): """Test that incomplete implementations raise TypeError.""" with pytest.raises(TypeError) as exc_info: - IncompleteLLM(name="Incomplete") + IncompleteLLM(name="Incomplete") # type: ignore[abstract] assert "Can't instantiate abstract class" in str(exc_info.value) @@ -132,7 +135,7 @@ def test_getattr_with_none_llm(self): class NullLLM(LLMInterface): """Implementation with None llm.""" - def __init__(self, name: str, system_prompt: str = None): + def __init__(self, name: str, system_prompt: Optional[str] = None): super().__init__(name, system_prompt) self.llm = None @@ -184,7 +187,7 @@ def test_getattr_preserves_attribute_type(self): # Create a fresh mock without spec for this test class FlexibleLLM(LLMInterface): - def __init__(self, name: str, system_prompt: str = None): + def __init__(self, name: str, system_prompt: Optional[str] = None): super().__init__(name, system_prompt) self.llm = MagicMock() self.llm.string_attr = "test string" @@ -206,3 +209,225 @@ def set_system_prompt(self, system_prompt: str) -> None: assert isinstance(llm.float_attr, float) assert isinstance(llm.bool_attr, bool) assert isinstance(llm.list_attr, list) + + +@pytest.mark.unit +class TestLLMInterfaceRetryLogic: + """Unit tests for retry logic and error handling in LLMInterface.""" + + def test_extract_http_status_code_from_status_code_attribute(self): + """Test extracting status code from exception.status_code attribute.""" + llm = ConcreteLLM(name="TestLLM") + + class ExceptionWithStatusCode(Exception): + def __init__(self, status_code): + self.status_code = status_code + super().__init__(f"HTTP {status_code}") + + exc = ExceptionWithStatusCode(429) + assert llm._extract_http_status_code(exc) == 429 + + def test_extract_http_status_code_from_response_attribute(self): + """Test extracting status code from exception.response.status_code.""" + llm = ConcreteLLM(name="TestLLM") + + class MockResponse: + def __init__(self, status_code): + self.status_code = status_code + + class ExceptionWithResponse(Exception): + def __init__(self, status_code): + self.response = MockResponse(status_code) + super().__init__(f"HTTP {status_code}") + + exc = ExceptionWithResponse(503) + assert llm._extract_http_status_code(exc) == 503 + + def test_extract_http_status_code_from_error_message(self): + """Test extracting status code from error message string.""" + llm = ConcreteLLM(name="TestLLM") + + exc = Exception("HTTP status 429: Too Many Requests") + assert llm._extract_http_status_code(exc) == 429 + + exc2 = Exception("Request failed with status_code 503") + assert llm._extract_http_status_code(exc2) == 503 + + def test_extract_http_status_code_error_code_pattern(self): + """Test extracting status code from 'Error Code' pattern.""" + llm = ConcreteLLM(name="TestLLM") + + exc = Exception("Error Code: 429") + assert llm._extract_http_status_code(exc) == 429 + + exc2 = Exception("Error Code 503 occurred") + assert llm._extract_http_status_code(exc2) == 503 + + exc3 = Exception("Error code: 500") + assert llm._extract_http_status_code(exc3) == 500 + + def test_extract_http_status_code_error_code_with_additional_text(self): + """Test extracting status code from 'Error code' with additional text after.""" + llm = ConcreteLLM(name="TestLLM") + + # Real-world example from Azure API + error_msg = ( + "Error code: 400 - {'type': 'error', 'error': " + "{'type': 
'invalid_request_error', 'message': " + "'messages.2: all messages must have non-empty content except for " + "the optional final assistant message'}, 'request_id': " + "'req_011CX84UXrNZGnUz2i9YM7AX'}" + ) + exc = Exception(error_msg) + assert llm._extract_http_status_code(exc) == 400 + + # Test with different formats + exc2 = Exception("Error code: 429 - Rate limit exceeded") + assert llm._extract_http_status_code(exc2) == 429 + + exc3 = Exception("Error code: 503 Service unavailable") + assert llm._extract_http_status_code(exc3) == 503 + + def test_extract_http_status_code_no_match(self): + """Test that None is returned when no status code can be extracted.""" + llm = ConcreteLLM(name="TestLLM") + + exc = Exception("Generic error message") + assert llm._extract_http_status_code(exc) is None + + @pytest.mark.asyncio + async def test_retry_with_backoff_empty_response_retries(self): + """Test that empty response content triggers retry.""" + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + call_count = 0 + + async def func_with_empty_then_valid(): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call returns empty response + mock_response = MagicMock() + mock_response.text = "" + return mock_response + else: + # Subsequent calls return valid response + mock_response = MagicMock() + mock_response.text = "Valid response" + return mock_response + + def validator(response_obj): + """Validate that response has non-empty content.""" + return bool(response_obj.text and response_obj.text.strip()) + + result = await llm._retry_with_backoff( + func_with_empty_then_valid, + operation_name="test_operation", + response_validator=validator, + ) + + assert result.text == "Valid response" + assert call_count == 2 # Should have retried once + + @pytest.mark.asyncio + async def test_retry_with_backoff_empty_response_exhausts_retries(self): + """Test that empty response raises error after max retries.""" + llm = ConcreteLLM(name="TestLLM", max_retries=2) + + async def func_always_empty(): + mock_response = MagicMock() + mock_response.text = "" + return mock_response + + def validator(response_obj): + """Validate that response has non-empty content.""" + return bool(response_obj.text and response_obj.text.strip()) + + with pytest.raises(RuntimeError) as exc_info: + await llm._retry_with_backoff( + func_always_empty, + operation_name="test_operation", + response_validator=validator, + ) + + assert "after 2 retries" in str(exc_info.value) + assert "response content is empty" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_retry_with_backoff_whitespace_only_response_retries(self): + """Test that whitespace-only response triggers retry.""" + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + call_count = 0 + + async def func_with_whitespace_then_valid(): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call returns whitespace-only response + mock_response = MagicMock() + mock_response.text = " \n\t " + return mock_response + else: + # Subsequent calls return valid response + mock_response = MagicMock() + mock_response.text = "Valid response" + return mock_response + + def validator(response_obj): + """Validate that response has non-empty content.""" + return bool(response_obj.text and response_obj.text.strip()) + + result = await llm._retry_with_backoff( + func_with_whitespace_then_valid, + operation_name="test_operation", + response_validator=validator, + ) + + assert result.text == "Valid response" + assert call_count == 2 # Should have retried 
once + + @pytest.mark.asyncio + async def test_retry_with_backoff_no_validator_passes_through(self): + """Test that without validator, empty response is returned.""" + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + async def func_with_empty(): + mock_response = MagicMock() + mock_response.text = "" + return mock_response + + # No validator provided + result = await llm._retry_with_backoff( + func_with_empty, + operation_name="test_operation", + ) + + assert result.text == "" # Empty response is returned without validation + + @pytest.mark.asyncio + async def test_retry_with_backoff_validator_returns_true_immediately(self): + """Test that valid response passes validator immediately.""" + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + call_count = 0 + + async def func_with_valid(): + nonlocal call_count + call_count += 1 + mock_response = MagicMock() + mock_response.text = "Valid response" + return mock_response + + def validator(response_obj): + """Validate that response has non-empty content.""" + return bool(response_obj.text and response_obj.text.strip()) + + result = await llm._retry_with_backoff( + func_with_valid, + operation_name="test_operation", + response_validator=validator, + ) + + assert result.text == "Valid response" + assert call_count == 1 # Should not retry for valid response From f71163e53dd00756c2cc1e7d604b266fef0ec585 Mon Sep 17 00:00:00 2001 From: Josh Gieringer Date: Thu, 15 Jan 2026 15:38:26 -0700 Subject: [PATCH 3/4] 3 retries is sufficient --- llm_clients/claude_llm.py | 2 +- llm_clients/gemini_llm.py | 2 +- llm_clients/llama_llm.py | 2 +- llm_clients/llm_interface.py | 2 +- llm_clients/openai_llm.py | 2 +- tests/unit/llm_clients/test_llm_interface.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llm_clients/claude_llm.py b/llm_clients/claude_llm.py index c9f25ca1..be2b86ed 100644 --- a/llm_clients/claude_llm.py +++ b/llm_clients/claude_llm.py @@ -23,7 +23,7 @@ def __init__( name: str, system_prompt: Optional[str] = None, model_name: Optional[str] = None, - max_retries: int = 10, + max_retries: int = 3, **kwargs, ): super().__init__(name, system_prompt, max_retries=max_retries) diff --git a/llm_clients/gemini_llm.py b/llm_clients/gemini_llm.py index 437da82c..0e549a46 100644 --- a/llm_clients/gemini_llm.py +++ b/llm_clients/gemini_llm.py @@ -23,7 +23,7 @@ def __init__( name: str, system_prompt: Optional[str] = None, model_name: Optional[str] = None, - max_retries: int = 10, + max_retries: int = 3, **kwargs, ): super().__init__(name, system_prompt, max_retries=max_retries) diff --git a/llm_clients/llama_llm.py b/llm_clients/llama_llm.py index d0b0fbe2..6af67778 100644 --- a/llm_clients/llama_llm.py +++ b/llm_clients/llama_llm.py @@ -22,7 +22,7 @@ def __init__( name: str, system_prompt: Optional[str] = None, model_name: Optional[str] = None, - max_retries: int = 10, + max_retries: int = 3, **kwargs, ): super().__init__(name, system_prompt, max_retries=max_retries) diff --git a/llm_clients/llm_interface.py b/llm_clients/llm_interface.py index ee90c39b..f81157f8 100644 --- a/llm_clients/llm_interface.py +++ b/llm_clients/llm_interface.py @@ -16,7 +16,7 @@ class LLMInterface(ABC): """ def __init__( - self, name: str, system_prompt: Optional[str] = None, max_retries: int = 10 + self, name: str, system_prompt: Optional[str] = None, max_retries: int = 3 ): self.name = name self.system_prompt = system_prompt or "" diff --git a/llm_clients/openai_llm.py b/llm_clients/openai_llm.py index f836e7fe..c9c3eaea 100644 --- 
a/llm_clients/openai_llm.py +++ b/llm_clients/openai_llm.py @@ -23,7 +23,7 @@ def __init__( name: str, system_prompt: Optional[str] = None, model_name: Optional[str] = None, - max_retries: int = 10, + max_retries: int = 3, **kwargs, ): super().__init__(name, system_prompt, max_retries=max_retries) diff --git a/tests/unit/llm_clients/test_llm_interface.py b/tests/unit/llm_clients/test_llm_interface.py index 8d5f6ca3..bcde94a6 100644 --- a/tests/unit/llm_clients/test_llm_interface.py +++ b/tests/unit/llm_clients/test_llm_interface.py @@ -10,7 +10,7 @@ class ConcreteLLM(LLMInterface): """Concrete implementation for testing abstract base class.""" def __init__( - self, name: str, system_prompt: Optional[str] = None, max_retries: int = 10 + self, name: str, system_prompt: Optional[str] = None, max_retries: int = 3 ): super().__init__(name, system_prompt, max_retries=max_retries) # Add a mock llm object for __getattr__ testing From 63e4d9c92ebe0c922ebcf98733abb77d6ee9948c Mon Sep 17 00:00:00 2001 From: Josh Gieringer Date: Thu, 15 Jan 2026 15:49:55 -0700 Subject: [PATCH 4/4] way more testing --- tests/unit/llm_clients/test_llm_interface.py | 476 ++++++++++++++++++- 1 file changed, 471 insertions(+), 5 deletions(-) diff --git a/tests/unit/llm_clients/test_llm_interface.py b/tests/unit/llm_clients/test_llm_interface.py index bcde94a6..76b906ea 100644 --- a/tests/unit/llm_clients/test_llm_interface.py +++ b/tests/unit/llm_clients/test_llm_interface.py @@ -1,11 +1,20 @@ -from typing import Optional -from unittest.mock import MagicMock +from typing import Any, Optional +from unittest.mock import AsyncMock, MagicMock, patch import pytest from llm_clients.llm_interface import LLMInterface +class ExceptionWithStatusCode(Exception): + """Exception with status_code attribute for testing.""" + + def __init__(self, status_code: int, message: str = ""): + self.status_code = status_code + self.response: Any = None # Can be set for testing Retry-After header + super().__init__(message or f"HTTP {status_code}") + + class ConcreteLLM(LLMInterface): """Concrete implementation for testing abstract base class.""" @@ -124,9 +133,9 @@ def set_system_prompt(self, system_prompt: str) -> None: llm = MinimalLLM(name="Minimal") - # Should raise AttributeError (or RecursionError due to hasattr in __getattr__) - # The current implementation has a recursion issue, but it still raises an error - with pytest.raises((AttributeError, RecursionError)): + # Should raise RecursionError when self.llm doesn't exist + # because hasattr(self, "llm") in __getattr__ calls __getattr__ again + with pytest.raises(RecursionError): _ = llm.some_attribute def test_getattr_with_none_llm(self): @@ -431,3 +440,460 @@ def validator(response_obj): assert result.text == "Valid response" assert call_count == 1 # Should not retry for valid response + + def test_extract_retry_after_from_headers(self): + """Test extracting Retry-After header from exception.response.headers.""" + llm = ConcreteLLM(name="TestLLM") + + class MockHeaders: + def __init__(self, retry_after): + self._headers = {"Retry-After": str(retry_after)} + + def get(self, key): + return self._headers.get(key) or self._headers.get(key.lower()) + + class MockResponse: + def __init__(self, retry_after): + self.headers = MockHeaders(retry_after) + + class ExceptionWithRetryAfter(Exception): + def __init__(self, retry_after): + self.response = MockResponse(retry_after) + super().__init__(f"Rate limited, retry after {retry_after}") + + exc = ExceptionWithRetryAfter(30) + assert 
llm._extract_retry_after(exc) == 30 + + def test_extract_retry_after_case_insensitive(self): + """Test that Retry-After header extraction is case-insensitive.""" + llm = ConcreteLLM(name="TestLLM") + + class MockHeaders: + def __init__(self, retry_after): + self._headers = {"retry-after": str(retry_after)} + + def get(self, key): + return self._headers.get(key) or self._headers.get(key.lower()) + + class MockResponse: + def __init__(self, retry_after): + self.headers = MockHeaders(retry_after) + + class ExceptionWithRetryAfter(Exception): + def __init__(self, retry_after): + self.response = MockResponse(retry_after) + super().__init__("Rate limited") + + exc = ExceptionWithRetryAfter(45) + assert llm._extract_retry_after(exc) == 45 + + def test_extract_retry_after_no_headers(self): + """Test that None is returned when headers don't exist.""" + llm = ConcreteLLM(name="TestLLM") + + class ExceptionWithoutHeaders(Exception): + def __init__(self): + self.response = object() # No headers attribute + super().__init__("Error") + + exc = ExceptionWithoutHeaders() + assert llm._extract_retry_after(exc) is None + + def test_extract_retry_after_no_response(self): + """Test that None is returned when response doesn't exist.""" + llm = ConcreteLLM(name="TestLLM") + + exc = Exception("Generic error") + assert llm._extract_retry_after(exc) is None + + def test_extract_retry_after_invalid_value(self): + """Test that None is returned when Retry-After value is invalid.""" + llm = ConcreteLLM(name="TestLLM") + + class MockHeaders: + def get(self, key): + return "invalid" # Not a number + + class MockResponse: + def __init__(self): + self.headers = MockHeaders() + + class ExceptionWithInvalidRetryAfter(Exception): + def __init__(self): + self.response = MockResponse() + super().__init__("Error") + + exc = ExceptionWithInvalidRetryAfter() + assert llm._extract_retry_after(exc) is None + + @pytest.mark.asyncio + async def test_retry_with_backoff_429_rate_limit(self): + """Test retry logic for 429 (Too Many Requests) status code.""" + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + call_count = 0 + + async def func_with_429_then_success(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ExceptionWithStatusCode(429, "HTTP status 429: Too Many Requests") + return "Success after retry" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await llm._retry_with_backoff( + func_with_429_then_success, operation_name="test_operation" + ) + + assert result == "Success after retry" + assert call_count == 2 + # Should sleep once with exponential backoff (2^0 = 1 second) + assert mock_sleep.call_count == 1 + mock_sleep.assert_called_with(1) + + @pytest.mark.asyncio + async def test_retry_with_backoff_429_with_retry_after_header(self): + """Test that 429 respects Retry-After header.""" + from unittest.mock import AsyncMock, patch + + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + call_count = 0 + + class MockHeaders: + def get(self, key): + return "15" if key.lower() == "retry-after" else None + + class MockResponse: + def __init__(self): + self.headers = MockHeaders() + + async def func_with_429_and_retry_after(): + nonlocal call_count + call_count += 1 + if call_count == 1: + exc = ExceptionWithStatusCode(429, "HTTP status 429: Too Many Requests") + exc.response = MockResponse() + raise exc + return "Success after retry" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await llm._retry_with_backoff( + 
func_with_429_and_retry_after, + operation_name="test_operation", + ) + + assert result == "Success after retry" + assert call_count == 2 + # Should use Retry-After header value (15s) instead of backoff + assert mock_sleep.call_count == 1 + mock_sleep.assert_called_with(15) + + @pytest.mark.asyncio + async def test_retry_with_backoff_500_internal_server_error(self): + """Test retry logic for 500 (Internal Server Error) status code.""" + llm = ConcreteLLM(name="TestLLM", max_retries=5) + + call_count = 0 + + async def func_with_500_then_success(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise ExceptionWithStatusCode( + 500, "HTTP status 500: Internal Server Error" + ) + return "Success after retries" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await llm._retry_with_backoff( + func_with_500_then_success, operation_name="test_operation" + ) + + assert result == "Success after retries" + assert call_count == 3 + # Should sleep twice with exponential backoff (2^0=1, 2^1=2) + assert mock_sleep.call_count == 2 + assert mock_sleep.call_args_list[0][0][0] == 1 + assert mock_sleep.call_args_list[1][0][0] == 2 + + @pytest.mark.asyncio + async def test_retry_with_backoff_500_limited_to_3_retries(self): + """Test that 500 status code is limited to 3 retries.""" + llm = ConcreteLLM(name="TestLLM", max_retries=10) + + call_count = 0 + + async def func_always_500(): + nonlocal call_count + call_count += 1 + raise ExceptionWithStatusCode(500, "HTTP status 500: Internal Server Error") + + with pytest.raises(RuntimeError) as exc_info: + await llm._retry_with_backoff( + func_always_500, operation_name="test_operation" + ) + + assert "after 3 retries" in str(exc_info.value) + assert call_count == 3 # Limited to 3 retries for 500 + + @pytest.mark.asyncio + async def test_retry_with_backoff_502_bad_gateway(self): + """Test retry logic for 502 (Bad Gateway) status code.""" + from unittest.mock import AsyncMock, patch + + llm = ConcreteLLM(name="TestLLM", max_retries=5) + + call_count = 0 + + async def func_with_502_then_success(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ExceptionWithStatusCode(502, "HTTP status 502: Bad Gateway") + return "Success after retry" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await llm._retry_with_backoff( + func_with_502_then_success, operation_name="test_operation" + ) + + assert result == "Success after retry" + assert call_count == 2 + assert mock_sleep.call_count == 1 + + @pytest.mark.asyncio + async def test_retry_with_backoff_502_limited_to_3_retries(self): + """Test that 502 status code is limited to 3 retries.""" + llm = ConcreteLLM(name="TestLLM", max_retries=10) + + call_count = 0 + + async def func_always_502(): + nonlocal call_count + call_count += 1 + raise ExceptionWithStatusCode(502, "HTTP status 502: Bad Gateway") + + with pytest.raises(RuntimeError) as exc_info: + await llm._retry_with_backoff( + func_always_502, operation_name="test_operation" + ) + + assert "after 3 retries" in str(exc_info.value) + assert call_count == 3 # Limited to 3 retries for 502 + + @pytest.mark.asyncio + async def test_retry_with_backoff_503_service_unavailable(self): + """Test retry logic for 503 (Service Unavailable) status code.""" + from unittest.mock import AsyncMock, patch + + llm = ConcreteLLM(name="TestLLM", max_retries=4) + + call_count = 0 + + async def func_with_503_then_success(): + nonlocal call_count + call_count += 1 + if call_count <= 2: + raise 
ExceptionWithStatusCode( + 503, "HTTP status 503: Service Unavailable" + ) + return "Success after retries" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await llm._retry_with_backoff( + func_with_503_then_success, operation_name="test_operation" + ) + + assert result == "Success after retries" + assert call_count == 3 + # Should sleep twice with exponential backoff + assert mock_sleep.call_count == 2 + + @pytest.mark.asyncio + async def test_retry_with_backoff_504_gateway_timeout(self): + """Test retry logic for 504 (Gateway Timeout) status code.""" + from unittest.mock import AsyncMock, patch + + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + call_count = 0 + + async def func_with_504_then_success(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ExceptionWithStatusCode(504, "HTTP status 504: Gateway Timeout") + return "Success after retry" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await llm._retry_with_backoff( + func_with_504_then_success, operation_name="test_operation" + ) + + assert result == "Success after retry" + assert call_count == 2 + assert mock_sleep.call_count == 1 + + @pytest.mark.asyncio + async def test_retry_with_backoff_529_overloaded(self): + """Test retry logic for 529 (Overloaded - Anthropic) status code.""" + from unittest.mock import AsyncMock, patch + + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + call_count = 0 + + async def func_with_529_then_success(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ExceptionWithStatusCode(529, "HTTP status 529: Overloaded") + return "Success after retry" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await llm._retry_with_backoff( + func_with_529_then_success, operation_name="test_operation" + ) + + assert result == "Success after retry" + assert call_count == 2 + assert mock_sleep.call_count == 1 + + @pytest.mark.asyncio + async def test_retry_with_backoff_exponential_backoff_timing(self): + """Test that exponential backoff timing is correct (2^attempt, max 60s).""" + from unittest.mock import AsyncMock, patch + + llm = ConcreteLLM(name="TestLLM", max_retries=5) + + call_count = 0 + + async def func_with_multiple_503(): + nonlocal call_count + call_count += 1 + if call_count <= 3: + raise ExceptionWithStatusCode( + 503, "HTTP status 503: Service Unavailable" + ) + return "Success" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await llm._retry_with_backoff( + func_with_multiple_503, operation_name="test_operation" + ) + + assert result == "Success" + assert call_count == 4 + # Should sleep 3 times with exponential backoff: 2^0=1, 2^1=2, 2^2=4 + assert mock_sleep.call_count == 3 + assert mock_sleep.call_args_list[0][0][0] == 1 + assert mock_sleep.call_args_list[1][0][0] == 2 + assert mock_sleep.call_args_list[2][0][0] == 4 + + @pytest.mark.asyncio + async def test_retry_with_backoff_exponential_backoff_capped_at_60(self): + """Test that exponential backoff is capped at 60 seconds.""" + from unittest.mock import AsyncMock, patch + + llm = ConcreteLLM(name="TestLLM", max_retries=10) + + call_count = 0 + + async def func_with_many_503(): + nonlocal call_count + call_count += 1 + if call_count <= 7: # Need 7 attempts to reach 2^6 = 64 > 60 + raise ExceptionWithStatusCode( + 503, "HTTP status 503: Service Unavailable" + ) + return "Success" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await 
llm._retry_with_backoff( + func_with_many_503, operation_name="test_operation" + ) + + assert result == "Success" + # Check that wait times are capped at 60 + wait_times = [call[0][0] for call in mock_sleep.call_args_list] + assert all(wait <= 60 for wait in wait_times) + # At attempt 6, 2^6 = 64, should be capped to 60 + assert 60 in wait_times + + @pytest.mark.asyncio + async def test_retry_with_backoff_non_retryable_status_code(self): + """Test that non-retryable status codes raise immediately.""" + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + async def func_with_400(): + raise ExceptionWithStatusCode(400, "HTTP status 400: Bad Request") + + with pytest.raises(RuntimeError) as exc_info: + await llm._retry_with_backoff( + func_with_400, operation_name="test_operation" + ) + + assert "Error in test_operation" in str(exc_info.value) + assert "400" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_retry_with_backoff_retryable_keyword_in_message(self): + """Test retryable keywords are retried even without status code.""" + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + call_count = 0 + + async def func_with_retryable_message(): + nonlocal call_count + call_count += 1 + if call_count == 1: + # No status code, but has retryable keyword + raise Exception("Rate limit exceeded") + return "Success" + + with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep: + result = await llm._retry_with_backoff( + func_with_retryable_message, operation_name="test_operation" + ) + + assert result == "Success" + assert call_count == 2 + # Should treat as 503 and retry + assert mock_sleep.call_count == 1 + + @pytest.mark.asyncio + async def test_retry_with_backoff_non_retryable_error_message(self): + """Test that errors without retryable keywords raise immediately.""" + llm = ConcreteLLM(name="TestLLM", max_retries=3) + + async def func_with_non_retryable_error(): + raise Exception("Invalid API key provided") + + with pytest.raises(RuntimeError) as exc_info: + await llm._retry_with_backoff( + func_with_non_retryable_error, operation_name="test_operation" + ) + + assert "Error in test_operation" in str(exc_info.value) + assert "Invalid API key" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_retry_with_backoff_max_retries_exceeded(self): + """Test that RuntimeError is raised when max retries are exceeded.""" + llm = ConcreteLLM(name="TestLLM", max_retries=2) + + call_count = 0 + + async def func_always_503(): + nonlocal call_count + call_count += 1 + raise ExceptionWithStatusCode(503, "HTTP status 503: Service Unavailable") + + with pytest.raises(RuntimeError) as exc_info: + await llm._retry_with_backoff( + func_always_503, operation_name="test_operation" + ) + + assert "after 2 retries" in str(exc_info.value) + assert call_count == 2 # max_retries attempts
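
For reference, a minimal usage sketch of the retry helper exercised above (illustrative only, not part of the patch): it assumes an LLMInterface subclass whose `self.llm` exposes an async `ainvoke`, as the LangChain chat models in the provider diffs do; the function name `generate_with_retry` and the `messages` argument are placeholders.

    async def generate_with_retry(client_llm, messages):
        # Wrap the provider call so _retry_with_backoff can re-invoke it.
        async def _invoke():
            return await client_llm.llm.ainvoke(messages)

        # Treat empty or whitespace-only content as a retryable failure,
        # mirroring the validator added to openai_llm.py above.
        def _validate_response(response_obj):
            return bool(response_obj.text and response_obj.text.strip())

        return await client_llm._retry_with_backoff(
            _invoke,
            operation_name="generate_response",
            response_validator=_validate_response,
        )

On failure the helper backs off exponentially (2**attempt seconds, capped at 60), honors a Retry-After header on 429 responses, and limits 500/502 errors to three attempts, so callers do not need their own sleep or retry loop.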