From 67caf9bed84c17939ee9b33f0ed5c2ff01b2ba76 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sat, 29 Nov 2025 01:56:29 -0500 Subject: [PATCH] loop - restart connection - gemini --- src/strands/experimental/bidi/agent/loop.py | 1 + .../experimental/bidi/models/bidi_model.py | 11 ++++- .../experimental/bidi/models/gemini_live.py | 25 ++++++++--- .../experimental/bidi/models/novasonic.py | 6 +-- .../experimental/bidi/agent/test_loop.py | 5 ++- .../bidi/models/test_gemini_live.py | 45 +++++++++++++++++++ 6 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/strands/experimental/bidi/agent/loop.py b/src/strands/experimental/bidi/agent/loop.py index 80245d9b2..13b7033a4 100644 --- a/src/strands/experimental/bidi/agent/loop.py +++ b/src/strands/experimental/bidi/agent/loop.py @@ -198,6 +198,7 @@ async def _restart_connection(self, timeout_error: BidiModelTimeoutError) -> Non self._agent.system_prompt, self._agent.tool_registry.get_all_tool_specs(), self._agent.messages, + **timeout_error.restart_config, ) self._task_pool.create(self._run_model()) except Exception as exception: diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py index bc2806e78..0d0da63d2 100644 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ b/src/strands/experimental/bidi/models/bidi_model.py @@ -118,4 +118,13 @@ class BidiModelTimeoutError(Exception): to create a seamless, uninterrupted experience for the user. """ - pass + def __init__(self, message: str, **restart_config: Any) -> None: + """Initialize error. + + Args: + message: Timeout message from model. + **restart_config: Configure restart specific behaviors in the call to model start. + """ + super().__init__(self, message) + + self.restart_config = restart_config diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 2e9a13b54..1f2b2d5cd 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -40,7 +40,7 @@ BidiUsageEvent, ModalityUsage, ) -from .bidi_model import BidiModel +from .bidi_model import BidiModel, BidiModelTimeoutError logger = logging.getLogger(__name__) @@ -92,6 +92,7 @@ def __init__( # Connection state (initialized in start()) self._live_session: Any = None self._live_session_context_manager: Any = None + self._live_session_handle: str | None = None self._connection_id: str | None = None def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: @@ -175,8 +176,8 @@ async def start( ) self._live_session = await self._live_session_context_manager.__aenter__() - # Send initial message history if provided - if messages: + # Gemini itself restores message history when resuming from session + if messages and "live_session_handle" not in kwargs: await self._send_message_history(messages) async def _send_message_history(self, messages: Messages) -> None: @@ -227,7 +228,22 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut Returns: List of event dicts (empty list if no events to emit). + + Raises: + BidiModelTimeoutError: If gemini responds with go away message. """ + if message.go_away: + raise BidiModelTimeoutError( + message.go_away.model_dump_json(), live_session_handle=self._live_session_handle + ) + + if message.session_resumption_update: + resumption_update = message.session_resumption_update + if resumption_update.resumable and resumption_update.new_handle: + self._live_session_handle = resumption_update.new_handle + logger.debug("session_handle=<%s> | updating gemini session handle", self._live_session_handle) + return [] + # Handle interruption first (from server_content) if message.server_content and message.server_content.interrupted: return [BidiInterruptionEvent(reason="user_speech")] @@ -491,8 +507,7 @@ def _build_live_config( if self.config: config_dict.update({k: v for k, v in self.config.items() if k != "audio"}) - # Override with any kwargs from start() - config_dict.update(kwargs) + config_dict["session_resumption"] = {"handle": kwargs.get("live_session_handle")} # Add system instruction if provided if system_prompt: diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py index 24c932ab0..713afe028 100644 --- a/src/strands/experimental/bidi/models/novasonic.py +++ b/src/strands/experimental/bidi/models/novasonic.py @@ -297,13 +297,13 @@ async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: event_data = await output.receive() except ValidationException as error: - if "InternalErrorCode=531" in str(error): + if "InternalErrorCode=531" in error.message: # nova also times out if user is silent for 175 seconds - raise BidiModelTimeoutError(error) from error + raise BidiModelTimeoutError(error.message) from error raise except ModelTimeoutException as error: - raise BidiModelTimeoutError(error) from error + raise BidiModelTimeoutError(error.message) from error if not event_data: continue diff --git a/tests/strands/experimental/bidi/agent/test_loop.py b/tests/strands/experimental/bidi/agent/test_loop.py index 68346ab19..d19cada60 100644 --- a/tests/strands/experimental/bidi/agent/test_loop.py +++ b/tests/strands/experimental/bidi/agent/test_loop.py @@ -42,7 +42,7 @@ async def loop(agent): @pytest.mark.asyncio async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerator): - timeout_error = BidiModelTimeoutError("test timeout") + timeout_error = BidiModelTimeoutError("test timeout", test_restart_config=1) text_event = BidiTextInputEvent(text="test after restart") agent.model.receive = unittest.mock.Mock(side_effect=[timeout_error, agenerator([text_event])]) @@ -63,10 +63,11 @@ async def test_bidi_agent_loop_receive_restart_connection(loop, agent, agenerato agent.model.stop.assert_called_once() assert agent.model.start.call_count == 2 - agent.model.start.assert_any_call( + agent.model.start.assert_called_with( agent.system_prompt, agent.tool_registry.get_all_tool_specs(), agent.messages, + test_restart_config=1, ) diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index dec83dbe3..a880bb223 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,6 +13,7 @@ import pytest from google.genai import types as genai_types +from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, @@ -279,6 +280,34 @@ async def test_receive_lifecycle_events(mock_genai_client, model, agenerator): assert event.connection_id == model._connection_id +@pytest.mark.asyncio +async def test_receive_timeout(mock_genai_client, model, agenerator): + mock_resumption_response = unittest.mock.Mock() + mock_resumption_response.go_away = None + mock_resumption_response.session_resumption_update = unittest.mock.Mock() + mock_resumption_response.session_resumption_update.resumable = True + mock_resumption_response.session_resumption_update.new_handle = "h1" + + mock_timeout_response = unittest.mock.Mock() + mock_timeout_response.go_away = unittest.mock.Mock() + mock_timeout_response.go_away.model_dump_json.return_value = "test timeout" + + _, mock_live_session, _ = mock_genai_client + mock_live_session.receive = unittest.mock.Mock( + return_value=agenerator([mock_resumption_response, mock_timeout_response]) + ) + + await model.start() + + with pytest.raises(BidiModelTimeoutError, match=r"test timeout"): + async for _ in model.receive(): + pass + + tru_handle = model._live_session_handle + exp_handle = "h1" + assert tru_handle == exp_handle + + @pytest.mark.asyncio async def test_event_conversion(mock_genai_client, model): """Test conversion of all Gemini Live event types to standard format.""" @@ -288,6 +317,8 @@ async def test_event_conversion(mock_genai_client, model): # Test text output (converted to transcript via model_turn.parts) mock_text = unittest.mock.Mock() mock_text.data = None + mock_text.go_away = None + mock_text.session_resumption_update = None mock_text.tool_call = None # Create proper server_content structure with model_turn @@ -319,6 +350,8 @@ async def test_event_conversion(mock_genai_client, model): # Test multiple text parts (should concatenate) mock_multi_text = unittest.mock.Mock() mock_multi_text.data = None + mock_multi_text.go_away = None + mock_multi_text.session_resumption_update = None mock_multi_text.tool_call = None mock_server_content_multi = unittest.mock.Mock() @@ -347,6 +380,8 @@ async def test_event_conversion(mock_genai_client, model): mock_audio = unittest.mock.Mock() mock_audio.text = None mock_audio.data = b"audio_data" + mock_audio.go_away = None + mock_audio.session_resumption_update = None mock_audio.tool_call = None mock_audio.server_content = None @@ -373,6 +408,8 @@ async def test_event_conversion(mock_genai_client, model): mock_tool = unittest.mock.Mock() mock_tool.text = None mock_tool.data = None + mock_tool.go_away = None + mock_tool.session_resumption_update = None mock_tool.tool_call = mock_tool_call mock_tool.server_content = None @@ -404,6 +441,8 @@ async def test_event_conversion(mock_genai_client, model): mock_tool_multi = unittest.mock.Mock() mock_tool_multi.text = None mock_tool_multi.data = None + mock_tool_multi.go_away = None + mock_tool_multi.session_resumption_update = None mock_tool_multi.tool_call = mock_tool_call_multi mock_tool_multi.server_content = None @@ -431,6 +470,8 @@ async def test_event_conversion(mock_genai_client, model): mock_interrupt = unittest.mock.Mock() mock_interrupt.text = None mock_interrupt.data = None + mock_interrupt.go_away = None + mock_interrupt.session_resumption_update = None mock_interrupt.tool_call = None mock_interrupt.server_content = mock_server_content @@ -549,6 +590,8 @@ async def test_custom_audio_rates_in_events(mock_genai_client, model_id, api_key mock_audio = unittest.mock.Mock() mock_audio.text = None mock_audio.data = b"audio_data" + mock_audio.go_away = None + mock_audio.session_resumption_update = None mock_audio.tool_call = None mock_audio.server_content = None @@ -577,6 +620,8 @@ async def test_default_audio_rates_in_events(mock_genai_client, model_id, api_ke mock_audio = unittest.mock.Mock() mock_audio.text = None mock_audio.data = b"audio_data" + mock_audio.go_away = None + mock_audio.session_resumption_update = None mock_audio.tool_call = None mock_audio.server_content = None