From 642752abb690a79220b8ecc4a24f7ad371ac9bf8 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 16:08:33 -0500 Subject: [PATCH 1/6] rename modules --- src/strands/experimental/bidi/__init__.py | 6 +- src/strands/experimental/bidi/agent/agent.py | 4 +- .../experimental/bidi/models/__init__.py | 6 +- .../experimental/bidi/models/bidi_model.py | 130 --- .../experimental/bidi/models/gemini_live.py | 2 +- .../experimental/bidi/models/novasonic.py | 760 ---------------- .../experimental/bidi/models/openai.py | 816 ------------------ .../experimental/bidi/types/bidi_model.py | 36 - src/strands/experimental/bidi/types/events.py | 2 +- .../bidi/models/test_gemini_live.py | 2 +- .../bidi/models/test_novasonic.py | 4 +- .../experimental/bidi/models/test_openai.py | 4 +- 12 files changed, 15 insertions(+), 1757 deletions(-) delete mode 100644 src/strands/experimental/bidi/models/bidi_model.py delete mode 100644 src/strands/experimental/bidi/models/novasonic.py delete mode 100644 src/strands/experimental/bidi/models/openai.py delete mode 100644 src/strands/experimental/bidi/types/bidi_model.py diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 7e2ad2bb3..13c5b51e1 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -18,12 +18,12 @@ from .io.audio import BidiAudioIO # Model interface (for custom implementations) -from .models.bidi_model import BidiModel +from .models.model import BidiModel # Model providers - What users need to create models from .models.gemini_live import BidiGeminiLiveModel -from .models.novasonic import BidiNovaSonicModel -from .models.openai import BidiOpenAIRealtimeModel +from .models.nova_sonic import BidiNovaSonicModel +from .models.openai_realtime import BidiOpenAIRealtimeModel # Built-in tools from .tools import stop_conversation diff --git a/src/strands/experimental/bidi/agent/agent.py 
b/src/strands/experimental/bidi/agent/agent.py index 74b65ba10..68075d0b2 100644 --- a/src/strands/experimental/bidi/agent/agent.py +++ b/src/strands/experimental/bidi/agent/agent.py @@ -30,8 +30,8 @@ from ...hooks.events import BidiAgentInitializedEvent from ...tools import ToolProvider from .._async import stop_all -from ..models.bidi_model import BidiModel -from ..models.novasonic import BidiNovaSonicModel +from ..models.model import BidiModel +from ..models.nova_sonic import BidiNovaSonicModel from ..types.agent import BidiAgentInput from ..types.events import ( BidiAudioInputEvent, diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index d1221df36..29a2229c5 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,9 +1,9 @@ """Bidirectional model interfaces and implementations.""" -from .bidi_model import BidiModel, BidiModelTimeoutError +from .model import BidiModel, BidiModelTimeoutError from .gemini_live import BidiGeminiLiveModel -from .novasonic import BidiNovaSonicModel -from .openai import BidiOpenAIRealtimeModel +from .nova_sonic import BidiNovaSonicModel +from .openai_realtime import BidiOpenAIRealtimeModel __all__ = [ "BidiModel", diff --git a/src/strands/experimental/bidi/models/bidi_model.py b/src/strands/experimental/bidi/models/bidi_model.py deleted file mode 100644 index 0d0da63d2..000000000 --- a/src/strands/experimental/bidi/models/bidi_model.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Bidirectional streaming model interface. - -Defines the abstract interface for models that support real-time bidirectional -communication with persistent connections. Unlike traditional request-response -models, bidirectional models maintain an open connection for streaming audio, -text, and tool interactions. 
- -Features: -- Persistent connection management with connect/close lifecycle -- Real-time bidirectional communication (send and receive simultaneously) -- Provider-agnostic event normalization -- Support for audio, text, image, and tool result streaming -""" - -import logging -from typing import Any, AsyncIterable, Protocol - -from ....types._events import ToolResultEvent -from ....types.content import Messages -from ....types.tools import ToolSpec -from ..types.events import ( - BidiInputEvent, - BidiOutputEvent, -) - -logger = logging.getLogger(__name__) - - -class BidiModel(Protocol): - """Protocol for bidirectional streaming models. - - This interface defines the contract for models that support persistent streaming - connections with real-time audio and text communication. Implementations handle - provider-specific protocols while exposing a standardized event-based API. - - Attributes: - config: Configuration dictionary with provider-specific settings. - """ - - config: dict[str, Any] - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs: Any, - ) -> None: - """Establish a persistent streaming connection with the model. - - Opens a bidirectional connection that remains active for real-time communication. - The connection supports concurrent sending and receiving of events until explicitly - closed. Must be called before any send() or receive() operations. - - Args: - system_prompt: System instructions to configure model behavior. - tools: Tool specifications that the model can invoke during the conversation. - messages: Initial conversation history to provide context. - **kwargs: Provider-specific configuration options. - """ - ... - - async def stop(self) -> None: - """Close the streaming connection and release resources. 
- - Terminates the active bidirectional connection and cleans up any associated - resources such as network connections, buffers, or background tasks. After - calling close(), the model instance cannot be used until start() is called again. - """ - ... - - def receive(self) -> AsyncIterable[BidiOutputEvent]: - """Receive streaming events from the model. - - Continuously yields events from the model as they arrive over the connection. - Events are normalized to a provider-agnostic format for uniform processing. - This method should be called in a loop or async task to process model responses. - - The stream continues until the connection is closed or an error occurs. - - Yields: - BidiOutputEvent: Standardized event objects containing audio output, - transcripts, tool calls, or control signals. - """ - ... - - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: - """Send content to the model over the active connection. - - Transmits user input or tool results to the model during an active streaming - session. Supports multiple content types including text, audio, images, and - tool execution results. Can be called multiple times during a conversation. - - Args: - content: The content to send. Must be one of: - - BidiTextInputEvent: Text message from the user - - BidiAudioInputEvent: Audio data for speech input - - BidiImageInputEvent: Image data for visual understanding - - ToolResultEvent: Result from a tool execution - - Example: - await model.send(BidiTextInputEvent(text="Hello", role="user")) - await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) - await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) - await model.send(ToolResultEvent(tool_result)) - """ - ... - - -class BidiModelTimeoutError(Exception): - """Model timeout error. - - Bidirectional models are often configured with a connection time limit. 
Nova sonic for example keeps the connection - open for 8 minutes max. Upon receiving a timeout, the agent loop is configured to restart the model connection so as - to create a seamless, uninterrupted experience for the user. - """ - - def __init__(self, message: str, **restart_config: Any) -> None: - """Initialize error. - - Args: - message: Timeout message from model. - **restart_config: Configure restart specific behaviors in the call to model start. - """ - super().__init__(self, message) - - self.restart_config = restart_config diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 1f2b2d5cd..efc1d1832 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -40,7 +40,7 @@ BidiUsageEvent, ModalityUsage, ) -from .bidi_model import BidiModel, BidiModelTimeoutError +from .model import BidiModel, BidiModelTimeoutError logger = logging.getLogger(__name__) diff --git a/src/strands/experimental/bidi/models/novasonic.py b/src/strands/experimental/bidi/models/novasonic.py deleted file mode 100644 index 713afe028..000000000 --- a/src/strands/experimental/bidi/models/novasonic.py +++ /dev/null @@ -1,760 +0,0 @@ -"""Nova Sonic bidirectional model provider for real-time streaming conversations. - -Implements the BidiModel interface for Amazon's Nova Sonic, handling the -complex event sequencing and audio processing required by Nova Sonic's -InvokeModelWithBidirectionalStream protocol. 
- -Nova Sonic specifics: -- Hierarchical event sequences: connectionStart → promptStart → content streaming -- Base64-encoded audio format with hex encoding -- Tool execution with content containers and identifier tracking -- 8-minute connection limits with proper cleanup sequences -- Interruption detection through stopReason events -""" - -import asyncio -import base64 -import json -import logging -import uuid -from typing import Any, AsyncGenerator, cast - -import boto3 -from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput -from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme -from aws_sdk_bedrock_runtime.models import ( - BidirectionalInputPayloadPart, - InvokeModelWithBidirectionalStreamInputChunk, - ModelTimeoutException, - ValidationException, -) -from smithy_aws_core.identity.static import StaticCredentialsResolver -from smithy_core.aio.eventstream import DuplexEventStream -from smithy_core.shapes import ShapeID - -from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ....types.content import Messages -from ....types.tools import ToolResult, ToolSpec, ToolUse -from .._async import stop_all -from ..types.bidi_model import AudioConfig -from ..types.events import ( - AudioChannel, - AudioSampleRate, - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionStartEvent, - BidiInputEvent, - BidiInterruptionEvent, - BidiOutputEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, - BidiUsageEvent, -) -from .bidi_model import BidiModel, BidiModelTimeoutError - -logger = logging.getLogger(__name__) - -# Nova Sonic configuration constants -NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} - -NOVA_AUDIO_INPUT_CONFIG = { - "mediaType": "audio/lpcm", - "sampleRateHertz": 16000, - "sampleSizeBits": 16, - "channelCount": 1, - "audioType": "SPEECH", - "encoding": 
"base64", -} - -NOVA_AUDIO_OUTPUT_CONFIG = { - "mediaType": "audio/lpcm", - "sampleRateHertz": 16000, - "sampleSizeBits": 16, - "channelCount": 1, - "voiceId": "matthew", - "encoding": "base64", - "audioType": "SPEECH", -} - -NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} -NOVA_TOOL_CONFIG = {"mediaType": "application/json"} - - -class BidiNovaSonicModel(BidiModel): - """Nova Sonic implementation for bidirectional streaming. - - Combines model configuration and connection state in a single class. - Manages Nova Sonic's complex event sequencing, audio format conversion, and - tool execution patterns while providing the standard BidiModel interface. - - Attributes: - _stream: open bedrock stream to nova sonic. - """ - - _stream: DuplexEventStream - - def __init__( - self, - model_id: str = "amazon.nova-sonic-v1:0", - provider_config: dict[str, Any] | None = None, - client_config: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """Initialize Nova Sonic bidirectional model. - - Args: - model_id: Model identifier (default: amazon.nova-sonic-v1:0) - provider_config: Model behavior (audio, inference settings) - client_config: AWS authentication (boto_session OR region, not both) - **kwargs: Reserved for future parameters. 
- """ - # Store model ID - self.model_id = model_id - - # Resolve client config with defaults - self._client_config = self._resolve_client_config(client_config or {}) - - # Resolve provider config with defaults - self.config = self._resolve_provider_config(provider_config or {}) - - # Store session and region for later use - self._session = self._client_config["boto_session"] - self.region = self._client_config["region"] - - # Track API-provided identifiers - self._connection_id: str | None = None - self._audio_content_name: str | None = None - self._current_completion_id: str | None = None - - # Indicates if model is done generating transcript - self._generation_stage: str | None = None - - # Ensure certain events are sent in sequence when required - self._send_lock = asyncio.Lock() - - logger.debug("model_id=<%s> | nova sonic model initialized", model_id) - - def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: - """Resolve AWS client config (creates boto session if needed).""" - if "boto_session" in config and "region" in config: - raise ValueError("Cannot specify both 'boto_session' and 'region' in client_config") - - resolved = config.copy() - - # Create boto session if not provided - if "boto_session" not in resolved: - resolved["boto_session"] = boto3.Session() - - # Resolve region from session or use default - if "region" not in resolved: - resolved["region"] = resolved["boto_session"].region_name or "us-east-1" - - return resolved - - def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: - """Merge user config with defaults (user takes precedence).""" - # Define default audio configuration - default_audio_config: AudioConfig = { - "input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), - "output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), - "channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), - "format": "pcm", - "voice": cast(str, 
NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), - } - - user_audio_config = config.get("audio", {}) - merged_audio = {**default_audio_config, **user_audio_config} - - resolved = { - "audio": merged_audio, - **{k: v for k, v in config.items() if k != "audio"}, - } - - if user_audio_config: - logger.debug("audio_config | merged user-provided config with defaults") - else: - logger.debug("audio_config | using default Nova Sonic audio configuration") - - return resolved - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs: Any, - ) -> None: - """Establish bidirectional connection to Nova Sonic. - - Args: - system_prompt: System instructions for the model. - tools: List of tools available to the model. - messages: Conversation history to initialize with. - **kwargs: Additional configuration options. - - Raises: - RuntimeError: If user calls start again without first stopping. - """ - if self._connection_id: - raise RuntimeError("model already started | call stop before starting again") - - logger.debug("nova connection starting") - - self._connection_id = str(uuid.uuid4()) - - # Get credentials from boto3 session (full credential chain) - credentials = self._session.get_credentials() - - if not credentials: - raise ValueError( - "no AWS credentials found. configure credentials via environment variables, " - "credential files, IAM roles, or SSO." 
- ) - - # Use static resolver with credentials configured as properties - resolver = StaticCredentialsResolver() - - config = Config( - endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", - region=self.region, - aws_credentials_identity_resolver=resolver, - auth_scheme_resolver=HTTPAuthSchemeResolver(), - auth_schemes={ShapeID("aws.auth#sigv4"): SigV4AuthScheme(service="bedrock")}, - # Configure static credentials as properties - aws_access_key_id=credentials.access_key, - aws_secret_access_key=credentials.secret_key, - aws_session_token=credentials.token, - ) - - self.client = BedrockRuntimeClient(config=config) - logger.debug("region=<%s> | nova sonic client initialized", self.region) - - client = BedrockRuntimeClient(config=config) - self._stream = await client.invoke_model_with_bidirectional_stream( - InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) - ) - logger.debug("region=<%s> | nova sonic client initialized", self.region) - - init_events = self._build_initialization_events(system_prompt, tools, messages) - logger.debug("event_count=<%d> | sending nova sonic initialization events", len(init_events)) - await self._send_nova_events(init_events) - - logger.info("connection_id=<%s> | nova sonic connection established", self._connection_id) - - def _build_initialization_events( - self, system_prompt: str | None, tools: list[ToolSpec] | None, messages: Messages | None - ) -> list[str]: - """Build the sequence of initialization events.""" - tools = tools or [] - events = [ - self._get_connection_start_event(), - self._get_prompt_start_event(tools), - *self._get_system_prompt_events(system_prompt), - ] - - # Add conversation history if provided - if messages: - events.extend(self._get_message_history_events(messages)) - logger.debug("message_count=<%d> | conversation history added to initialization", len(messages)) - - return events - - def _log_event_type(self, nova_event: dict[str, Any]) -> None: - """Log specific Nova 
Sonic event types for debugging.""" - if "usageEvent" in nova_event: - logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"]) - elif "textOutput" in nova_event: - logger.debug("nova text output received") - elif "toolUse" in nova_event: - tool_use = nova_event["toolUse"] - logger.debug( - "tool_name=<%s>, tool_use_id=<%s> | nova tool use received", - tool_use["toolName"], - tool_use["toolUseId"], - ) - elif "audioOutput" in nova_event: - audio_content = nova_event["audioOutput"]["content"] - audio_bytes = base64.b64decode(audio_content) - logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) - - async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: - """Receive Nova Sonic events and convert to provider-agnostic format. - - Raises: - RuntimeError: If start has not been called. - """ - if not self._connection_id: - raise RuntimeError("model not started | call start before receiving") - - logger.debug("nova event stream starting") - yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) - - _, output = await self._stream.await_output() - while True: - try: - event_data = await output.receive() - - except ValidationException as error: - if "InternalErrorCode=531" in error.message: - # nova also times out if user is silent for 175 seconds - raise BidiModelTimeoutError(error.message) from error - raise - - except ModelTimeoutException as error: - raise BidiModelTimeoutError(error.message) from error - - if not event_data: - continue - - nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] - self._log_event_type(nova_event) - - model_event = self._convert_nova_event(nova_event) - if model_event: - yield model_event - - async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: - """Unified send method for all content types. Sends the given content to Nova Sonic. - - Dispatches to appropriate internal handler based on content type. 
- - Args: - content: Input event. - - Raises: - ValueError: If content type not supported (e.g., image content). - """ - if not self._connection_id: - raise RuntimeError("model not started | call start before sending") - - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - else: - raise ValueError(f"content_type={type(content)} | content not supported") - - async def _start_audio_connection(self) -> None: - """Internal: Start audio input connection (call once before sending audio chunks).""" - logger.debug("nova audio connection starting") - self._audio_content_name = str(uuid.uuid4()) - - # Build audio input configuration from config - audio_input_config = { - "mediaType": "audio/lpcm", - "sampleRateHertz": self.config["audio"]["input_rate"], - "sampleSizeBits": 16, - "channelCount": self.config["audio"]["channels"], - "audioType": "SPEECH", - "encoding": "base64", - } - - audio_content_start = json.dumps( - { - "event": { - "contentStart": { - "promptName": self._connection_id, - "contentName": self._audio_content_name, - "type": "AUDIO", - "interactive": True, - "role": "USER", - "audioInputConfiguration": audio_input_config, - } - } - } - ) - - await self._send_nova_events([audio_content_start]) - - async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: - """Internal: Send audio using Nova Sonic protocol-specific format.""" - # Start audio connection if not already active - if not self._audio_content_name: - await self._start_audio_connection() - - # Audio is already base64 encoded in the event - # Send audio input event - audio_event = json.dumps( - { - "event": { - "audioInput": { - "promptName": self._connection_id, - "contentName": self._audio_content_name, 
- "content": audio_input.audio, - } - } - } - ) - - await self._send_nova_events([audio_event]) - - async def _end_audio_input(self) -> None: - """Internal: End current audio input connection to trigger Nova Sonic processing.""" - if not self._audio_content_name: - return - - logger.debug("nova audio connection ending") - - audio_content_end = json.dumps( - {"event": {"contentEnd": {"promptName": self._connection_id, "contentName": self._audio_content_name}}} - ) - - await self._send_nova_events([audio_content_end]) - self._audio_content_name = None - - async def _send_text_content(self, text: str) -> None: - """Internal: Send text content using Nova Sonic format.""" - content_name = str(uuid.uuid4()) - events = [ - self._get_text_content_start_event(content_name), - self._get_text_input_event(content_name, text), - self._get_content_end_event(content_name), - ] - await self._send_nova_events(events) - - async def _send_tool_result(self, tool_result: ToolResult) -> None: - """Internal: Send tool result using Nova Sonic toolResult format.""" - tool_use_id = tool_result["toolUseId"] - - logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) - - # Validate content types and preserve structure - content = tool_result.get("content", []) - - # Validate all content types are supported - for block in content: - if "text" not in block and "json" not in block: - # Unsupported content type - raise error - raise ValueError( - f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " - f"Content type not supported by Nova Sonic" - ) - - # Optimize for single content item - unwrap the array - if len(content) == 1: - result_data = cast(dict[str, Any], content[0]) - else: - # Multiple items - send as array - result_data = {"content": content} - - content_name = str(uuid.uuid4()) - events = [ - self._get_tool_content_start_event(content_name, tool_use_id), - self._get_tool_result_event(content_name, result_data), - 
self._get_content_end_event(content_name), - ] - await self._send_nova_events(events) - - async def stop(self) -> None: - """Close Nova Sonic connection with proper cleanup sequence.""" - logger.debug("nova connection cleanup starting") - - async def stop_events() -> None: - if not self._connection_id: - return - - await self._end_audio_input() - cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] - await self._send_nova_events(cleanup_events) - - async def stop_stream() -> None: - if not hasattr(self, "_stream"): - return - - await self._stream.close() - - async def stop_connection() -> None: - self._connection_id = None - - await stop_all(stop_events, stop_stream, stop_connection) - - logger.debug("nova connection closed") - - def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | None: - """Convert Nova Sonic events to TypedEvent format.""" - # Handle completion start - track completionId - if "completionStart" in nova_event: - completion_data = nova_event["completionStart"] - self._current_completion_id = completion_data.get("completionId") - logger.debug("completion_id=<%s> | nova completion started", self._current_completion_id) - return None - - # Handle completion end - if "completionEnd" in nova_event: - completion_data = nova_event["completionEnd"] - completion_id = completion_data.get("completionId", self._current_completion_id) - stop_reason = completion_data.get("stopReason", "END_TURN") - - event = BidiResponseCompleteEvent( - response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing - stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete", - ) - - # Clear completion tracking - self._current_completion_id = None - return event - - # Handle audio output - if "audioOutput" in nova_event: - # Audio is already base64 string from Nova Sonic - audio_content = nova_event["audioOutput"]["content"] - return BidiAudioStreamEvent( - audio=audio_content, - format="pcm", - 
sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), - channels=cast(AudioChannel, self.config["audio"]["channels"]), - ) - - # Handle text output (transcripts) - elif "textOutput" in nova_event: - text_output = nova_event["textOutput"] - text_content = text_output["content"] - # Check for Nova Sonic interruption pattern - if '{ "interrupted" : true }' in text_content: - logger.debug("nova interruption detected in text output") - return BidiInterruptionEvent(reason="user_speech") - - return BidiTranscriptStreamEvent( - delta={"text": text_content}, - text=text_content, - role=text_output["role"].lower(), - is_final=self._generation_stage == "FINAL", - current_transcript=text_content, - ) - - # Handle tool use - if "toolUse" in nova_event: - tool_use = nova_event["toolUse"] - tool_use_event: ToolUse = { - "toolUseId": tool_use["toolUseId"], - "name": tool_use["toolName"], - "input": json.loads(tool_use["content"]), - } - # Return ToolUseStreamEvent - cast to dict for type compatibility - return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) - - # Handle interruption - if nova_event.get("stopReason") == "INTERRUPTED": - logger.debug("nova interruption detected via stop reason") - return BidiInterruptionEvent(reason="user_speech") - - # Handle usage events - convert to multimodal usage format - if "usageEvent" in nova_event: - usage_data = nova_event["usageEvent"] - total_input = usage_data.get("totalInputTokens", 0) - total_output = usage_data.get("totalOutputTokens", 0) - - return BidiUsageEvent( - input_tokens=total_input, - output_tokens=total_output, - total_tokens=usage_data.get("totalTokens", total_input + total_output), - ) - - # Handle content start events (emit response start) - if "contentStart" in nova_event: - content_data = nova_event["contentStart"] - if content_data["type"] == "TEXT": - self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] - - # 
Emit response start event using API-provided completionId - # completionId should already be tracked from completionStart event - return BidiResponseStartEvent( - response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing - ) - - if "contentEnd" in nova_event: - self._generation_stage = None - - # Ignore all other events - return None - - def _get_connection_start_event(self) -> str: - """Generate Nova Sonic connection start event.""" - return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) - - def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: - """Generate Nova Sonic prompt start event with tool configuration.""" - # Build audio output configuration from config - audio_output_config = { - "mediaType": "audio/lpcm", - "sampleRateHertz": self.config["audio"]["output_rate"], - "sampleSizeBits": 16, - "channelCount": self.config["audio"]["channels"], - "voiceId": self.config["audio"].get("voice", "matthew"), - "encoding": "base64", - "audioType": "SPEECH", - } - - prompt_start_event: dict[str, Any] = { - "event": { - "promptStart": { - "promptName": self._connection_id, - "textOutputConfiguration": NOVA_TEXT_CONFIG, - "audioOutputConfiguration": audio_output_config, - } - } - } - - if tools: - tool_config = self._build_tool_configuration(tools) - prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG - prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} - - return json.dumps(prompt_start_event) - - def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict[str, Any]]: - """Build tool configuration from tool specs.""" - tool_config: list[dict[str, Any]] = [] - for tool in tools: - input_schema = ( - {"json": json.dumps(tool["inputSchema"]["json"])} - if "json" in tool["inputSchema"] - else {"json": json.dumps(tool["inputSchema"])} - ) - - tool_config.append( - {"toolSpec": {"name": tool["name"], 
"description": tool["description"], "inputSchema": input_schema}} - ) - return tool_config - - def _get_system_prompt_events(self, system_prompt: str | None) -> list[str]: - """Generate system prompt events.""" - content_name = str(uuid.uuid4()) - return [ - self._get_text_content_start_event(content_name, "SYSTEM"), - self._get_text_input_event(content_name, system_prompt or ""), - self._get_content_end_event(content_name), - ] - - def _get_message_history_events(self, messages: Messages) -> list[str]: - """Generate conversation history events from agent messages. - - Converts agent message history to Nova Sonic format following the - contentStart/textInput/contentEnd pattern for each message. - - Args: - messages: List of conversation messages with role and content. - - Returns: - List of JSON event strings for Nova Sonic. - """ - events = [] - - for message in messages: - role = message["role"].upper() # Convert to ASSISTANT or USER - content_blocks = message.get("content", []) - - # Extract text content from content blocks - text_parts = [] - for block in content_blocks: - if "text" in block: - text_parts.append(block["text"]) - - # Combine all text parts - if text_parts: - combined_text = "\n".join(text_parts) - content_name = str(uuid.uuid4()) - - # Add contentStart, textInput, and contentEnd events - events.extend( - [ - self._get_text_content_start_event(content_name, role), - self._get_text_input_event(content_name, combined_text), - self._get_content_end_event(content_name), - ] - ) - - return events - - def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: - """Generate text content start event.""" - return json.dumps( - { - "event": { - "contentStart": { - "promptName": self._connection_id, - "contentName": content_name, - "type": "TEXT", - "role": role, - "interactive": True, - "textInputConfiguration": NOVA_TEXT_CONFIG, - } - } - } - ) - - def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> 
str: - """Generate tool content start event.""" - return json.dumps( - { - "event": { - "contentStart": { - "promptName": self._connection_id, - "contentName": content_name, - "interactive": False, - "type": "TOOL", - "role": "TOOL", - "toolResultInputConfiguration": { - "toolUseId": tool_use_id, - "type": "TEXT", - "textInputConfiguration": NOVA_TEXT_CONFIG, - }, - } - } - } - ) - - def _get_text_input_event(self, content_name: str, text: str) -> str: - """Generate text input event.""" - return json.dumps( - {"event": {"textInput": {"promptName": self._connection_id, "contentName": content_name, "content": text}}} - ) - - def _get_tool_result_event(self, content_name: str, result: dict[str, Any]) -> str: - """Generate tool result event.""" - return json.dumps( - { - "event": { - "toolResult": { - "promptName": self._connection_id, - "contentName": content_name, - "content": json.dumps(result), - } - } - } - ) - - def _get_content_end_event(self, content_name: str) -> str: - """Generate content end event.""" - return json.dumps({"event": {"contentEnd": {"promptName": self._connection_id, "contentName": content_name}}}) - - def _get_prompt_end_event(self) -> str: - """Generate prompt end event.""" - return json.dumps({"event": {"promptEnd": {"promptName": self._connection_id}}}) - - def _get_connection_end_event(self) -> str: - """Generate connection end event.""" - return json.dumps({"event": {"connectionEnd": {}}}) - - async def _send_nova_events(self, events: list[str]) -> None: - """Send event JSON string to Nova Sonic stream. - - A lock is used to send events in sequence when required (e.g., tool result start, content, and end). - - Args: - events: Jsonified events. 
- """ - async with self._send_lock: - for event in events: - bytes_data = event.encode("utf-8") - chunk = InvokeModelWithBidirectionalStreamInputChunk( - value=BidirectionalInputPayloadPart(bytes_=bytes_data) - ) - await self._stream.input_stream.send(chunk) - logger.debug("nova sonic event sent successfully") diff --git a/src/strands/experimental/bidi/models/openai.py b/src/strands/experimental/bidi/models/openai.py deleted file mode 100644 index bfe3ad533..000000000 --- a/src/strands/experimental/bidi/models/openai.py +++ /dev/null @@ -1,816 +0,0 @@ -"""OpenAI Realtime API provider for Strands bidirectional streaming. - -Provides real-time audio and text communication through OpenAI's Realtime API -with WebSocket connections, voice activity detection, and function calling. -""" - -import asyncio -import json -import logging -import os -import time -import uuid -from typing import Any, AsyncGenerator, Literal, cast - -import websockets -from websockets import ClientConnection - -from ....types._events import ToolResultEvent, ToolUseStreamEvent -from ....types.content import Messages -from ....types.tools import ToolResult, ToolSpec, ToolUse -from .._async import stop_all -from ..types.bidi_model import AudioConfig -from ..types.events import ( - AudioSampleRate, - BidiAudioInputEvent, - BidiAudioStreamEvent, - BidiConnectionStartEvent, - BidiInputEvent, - BidiInterruptionEvent, - BidiOutputEvent, - BidiResponseCompleteEvent, - BidiResponseStartEvent, - BidiTextInputEvent, - BidiTranscriptStreamEvent, - BidiUsageEvent, - ModalityUsage, - Role, - StopReason, -) -from .bidi_model import BidiModel, BidiModelTimeoutError - -logger = logging.getLogger(__name__) - -# Test idle_timeout_ms - -# OpenAI Realtime API configuration -OPENAI_MAX_TIMEOUT_S = 3000 # 50 minutes -"""Max timeout before closing connection. - -OpenAI documents a 60 minute limit on realtime sessions -(https://platform.openai.com/docs/guides/realtime-conversations#session-lifecycle-events). 
However, OpenAI does not -emit any warnings when approaching the limit. As a workaround, we configure a max timeout client side to gracefully -handle the connection closure. We set the max to 50 minutes to provide enough buffer before hitting the real limit. -""" -OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" -DEFAULT_MODEL = "gpt-realtime" -DEFAULT_SAMPLE_RATE = 24000 - -DEFAULT_SESSION_CONFIG = { - "type": "realtime", - "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", - "output_modalities": ["audio"], - "audio": { - "input": { - "format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, - "transcription": {"model": "gpt-4o-transcribe"}, - "turn_detection": { - "type": "server_vad", - "threshold": 0.5, - "prefix_padding_ms": 300, - "silence_duration_ms": 500, - }, - }, - "output": {"format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, "voice": "alloy"}, - }, -} - - -class BidiOpenAIRealtimeModel(BidiModel): - """OpenAI Realtime API implementation for bidirectional streaming. - - Combines model configuration and connection state in a single class. - Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, - function calling, and event conversion to Strands format. - """ - - _websocket: ClientConnection - _start_time: int - - def __init__( - self, - model_id: str = DEFAULT_MODEL, - provider_config: dict[str, Any] | None = None, - client_config: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - """Initialize OpenAI Realtime bidirectional model. - - Args: - model_id: Model identifier (default: gpt-realtime) - provider_config: Model behavior (audio, instructions, turn_detection, etc.) - client_config: Authentication (api_key, organization, project) - Falls back to OPENAI_API_KEY, OPENAI_ORGANIZATION, OPENAI_PROJECT env vars - **kwargs: Reserved for future parameters. 
- - """ - # Store model ID - self.model_id = model_id - - # Resolve client config with defaults and env vars - self._client_config = self._resolve_client_config(client_config or {}) - - # Resolve provider config with defaults - self.config = self._resolve_provider_config(provider_config or {}) - - # Store client config values for later use - self.api_key = self._client_config["api_key"] - self.organization = self._client_config.get("organization") - self.project = self._client_config.get("project") - self.timeout_s = self._client_config["timeout_s"] - - if self.timeout_s > OPENAI_MAX_TIMEOUT_S: - raise ValueError( - f"timeout_s=<{self.timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" - ) - - # Connection state (initialized in start()) - self._connection_id: str | None = None - - self._function_call_buffer: dict[str, Any] = {} - - logger.debug("model=<%s> | openai realtime model initialized", model_id) - - def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: - """Resolve client config with env var fallback (config takes precedence).""" - resolved = config.copy() - - if "api_key" not in resolved: - resolved["api_key"] = os.getenv("OPENAI_API_KEY") - - if not resolved.get("api_key"): - raise ValueError( - "OpenAI API key is required. Provide via client_config={'api_key': '...'} " - "or set OPENAI_API_KEY environment variable." 
- ) - if "organization" not in resolved: - env_org = os.getenv("OPENAI_ORGANIZATION") - if env_org: - resolved["organization"] = env_org - - if "project" not in resolved: - env_project = os.getenv("OPENAI_PROJECT") - if env_project: - resolved["project"] = env_project - - if "timeout_s" not in resolved: - resolved["timeout_s"] = OPENAI_MAX_TIMEOUT_S - - return resolved - - def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: - """Merge user config with defaults (user takes precedence).""" - # Extract voice from provider-specific audio.output.voice if present - provider_voice = None - if "audio" in config and isinstance(config["audio"], dict): - if "output" in config["audio"] and isinstance(config["audio"]["output"], dict): - provider_voice = config["audio"]["output"].get("voice") - - # Define default audio configuration - default_audio: AudioConfig = { - "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), - "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), - "channels": 1, - "format": "pcm", - "voice": provider_voice or "alloy", - } - - user_audio = config.get("audio", {}) - merged_audio = {**default_audio, **user_audio} - - resolved = { - "audio": merged_audio, - **{k: v for k, v in config.items() if k != "audio"}, - } - - if user_audio: - logger.debug("audio_config | merged user-provided config with defaults") - else: - logger.debug("audio_config | using default OpenAI Realtime audio configuration") - - return resolved - - async def start( - self, - system_prompt: str | None = None, - tools: list[ToolSpec] | None = None, - messages: Messages | None = None, - **kwargs: Any, - ) -> None: - """Establish bidirectional connection to OpenAI Realtime API. - - Args: - system_prompt: System instructions for the model. - tools: List of tools available to the model. - messages: Conversation history to initialize with. - **kwargs: Additional configuration options. 
- """ - if self._connection_id: - raise RuntimeError("model already started | call stop before starting again") - - logger.debug("openai realtime connection starting") - - # Initialize connection state - self._connection_id = str(uuid.uuid4()) - self._start_time = int(time.time()) - - self._function_call_buffer = {} - - # Establish WebSocket connection - url = f"{OPENAI_REALTIME_URL}?model={self.model_id}" - - headers = [("Authorization", f"Bearer {self.api_key}")] - if self.organization: - headers.append(("OpenAI-Organization", self.organization)) - if self.project: - headers.append(("OpenAI-Project", self.project)) - - self._websocket = await websockets.connect(url, additional_headers=headers) - logger.debug("connection_id=<%s> | websocket connected successfully", self._connection_id) - - # Configure session - session_config = self._build_session_config(system_prompt, tools) - await self._send_event({"type": "session.update", "session": session_config}) - - # Add conversation history if provided - if messages: - await self._add_conversation_history(messages) - - def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: - """Create standardized transcript event. 
- - Args: - text: The transcript text - role: The role (will be normalized to lowercase) - is_final: Whether this is the final transcript - """ - # Normalize role to lowercase and ensure it's either "user" or "assistant" - normalized_role = role.lower() if isinstance(role, str) else "assistant" - if normalized_role not in ["user", "assistant"]: - normalized_role = "assistant" - - return BidiTranscriptStreamEvent( - delta={"text": text}, - text=text, - role=cast(Role, normalized_role), - is_final=is_final, - current_transcript=text if is_final else None, - ) - - def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEvent | None: - """Create standardized interruption event for voice activity.""" - # Only speech_started triggers interruption - if activity_type == "speech_started": - return BidiInterruptionEvent(reason="user_speech") - # Other voice activity events are logged but don't create events - return None - - def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict[str, Any]: - """Build session configuration for OpenAI Realtime API.""" - config: dict[str, Any] = DEFAULT_SESSION_CONFIG.copy() - - if system_prompt: - config["instructions"] = system_prompt - - if tools: - config["tools"] = self._convert_tools_to_openai_format(tools) - - # Apply user-provided session configuration - supported_params = { - "type", - "output_modalities", - "instructions", - "voice", - "tools", - "tool_choice", - "input_audio_format", - "output_audio_format", - "input_audio_transcription", - "turn_detection", - } - - for key, value in self.config.items(): - if key == "audio": - continue - elif key in supported_params: - config[key] = value - else: - logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) - - audio_config = self.config["audio"] - - if "voice" in audio_config: - config.setdefault("audio", {}).setdefault("output", {})["voice"] = audio_config["voice"] - - if "input_rate" in 
audio_config: - config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {})["rate"] = audio_config[ - "input_rate" - ] - - if "output_rate" in audio_config: - config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {})["rate"] = audio_config[ - "output_rate" - ] - - return config - - def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: - """Convert Strands tool specifications to OpenAI Realtime API format.""" - openai_tools = [] - - for tool in tools: - input_schema = tool["inputSchema"] - if "json" in input_schema: - schema = ( - json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] - ) - else: - schema = input_schema - - # OpenAI Realtime API expects flat structure, not nested under "function" - openai_tool = { - "type": "function", - "name": tool["name"], - "description": tool["description"], - "parameters": schema, - } - openai_tools.append(openai_tool) - - return openai_tools - - async def _add_conversation_history(self, messages: Messages) -> None: - """Add conversation history to the session. - - Converts agent message history to OpenAI Realtime API format using - conversation.item.create events for each message. - - Note: OpenAI Realtime API has a 32-character limit on call_id, so we truncate - UUIDs consistently to ensure tool calls and their results match. - - Args: - messages: List of conversation messages with role and content. 
- """ - # Track tool call IDs to ensure consistency between calls and results - call_id_map: dict[str, str] = {} - - # First pass: collect all tool call IDs - for message in messages: - for block in message.get("content", []): - if "toolUse" in block: - tool_use = block["toolUse"] - original_id = tool_use["toolUseId"] - call_id = original_id[:32] - call_id_map[original_id] = call_id - - # Second pass: send messages - for message in messages: - role = message["role"] - content_blocks = message.get("content", []) - - # Build content array for OpenAI format - openai_content = [] - - for block in content_blocks: - if "text" in block: - # Text content - use appropriate type based on role - # User messages use "input_text", assistant messages use "output_text" - if role == "user": - openai_content.append({"type": "input_text", "text": block["text"]}) - else: # assistant - openai_content.append({"type": "output_text", "text": block["text"]}) - elif "toolUse" in block: - # Tool use - create as function_call item - tool_use = block["toolUse"] - original_id = tool_use["toolUseId"] - # Use pre-mapped call_id - call_id = call_id_map[original_id] - - tool_item = { - "type": "conversation.item.create", - "item": { - "type": "function_call", - "call_id": call_id, - "name": tool_use["name"], - "arguments": json.dumps(tool_use["input"]), - }, - } - await self._send_event(tool_item) - continue # Tool use is sent separately, not in message content - elif "toolResult" in block: - # Tool result - create as function_call_output item - tool_result = block["toolResult"] - original_id = tool_result["toolUseId"] - - # Validate content types and serialize, preserving structure - result_output = "" - if "content" in tool_result: - # First validate all content types are supported - for result_block in tool_result["content"]: - if "text" not in result_block and "json" not in result_block: - # Unsupported content type - raise error - raise ValueError( - f"tool_use_id=<{original_id}>, 
content_types=<{list(result_block.keys())}> | " - f"Content type not supported by OpenAI Realtime API" - ) - - # Preserve structure by JSON-dumping the entire content array - result_output = json.dumps(tool_result["content"]) - - # Use mapped call_id if available, otherwise skip orphaned result - if original_id not in call_id_map: - continue # Skip this tool result since we don't have the call - - call_id = call_id_map[original_id] - - result_item = { - "type": "conversation.item.create", - "item": { - "type": "function_call_output", - "call_id": call_id, - "output": result_output, - }, - } - await self._send_event(result_item) - continue # Tool result is sent separately, not in message content - - # Only create message item if there's text content - if openai_content: - conversation_item = { - "type": "conversation.item.create", - "item": {"type": "message", "role": role, "content": openai_content}, - } - await self._send_event(conversation_item) - - logger.debug("message_count=<%d> | conversation history added to openai session", len(messages)) - - async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: - """Receive OpenAI events and convert to Strands TypedEvent format.""" - if not self._connection_id: - raise RuntimeError("model not started | call start before sending/receiving") - - yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) - - while True: - duration = time.time() - self._start_time - if duration >= self.timeout_s: - raise BidiModelTimeoutError(f"timeout_s=<{self.timeout_s}>") - - try: - message = await asyncio.wait_for(self._websocket.recv(), timeout=10) - except asyncio.TimeoutError: - continue - - openai_event = json.loads(message) - - for event in self._convert_openai_event(openai_event) or []: - yield event - - def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutputEvent] | None: - """Convert OpenAI events to Strands TypedEvent format.""" - event_type = 
openai_event.get("type") - - # Turn start - response begins - if event_type == "response.created": - response = openai_event.get("response", {}) - response_id = response.get("id", str(uuid.uuid4())) - return [BidiResponseStartEvent(response_id=response_id)] - - # Audio output - elif event_type == "response.output_audio.delta": - # Audio is already base64 string from OpenAI - # Use the resolved output sample rate from our merged configuration - sample_rate = self.config["audio"]["output_rate"] - - # Channels from config is guaranteed to be 1 or 2 - channels = cast(Literal[1, 2], self.config["audio"]["channels"]) - return [ - BidiAudioStreamEvent( - audio=openai_event["delta"], - format="pcm", - sample_rate=sample_rate, - channels=channels, - ) - ] - - # Assistant text output events - combine multiple similar events - elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: - role = openai_event.get("role", "assistant") - return [ - self._create_text_event( - openai_event["delta"], role.lower() if isinstance(role, str) else "assistant", is_final=False - ) - ] - - elif event_type in ["response.output_audio_transcript.done"]: - role = openai_event.get("role", "assistant").lower() - return [self._create_text_event(openai_event["transcript"], role)] - - elif event_type in ["response.output_text.done"]: - role = openai_event.get("role", "assistant").lower() - return [self._create_text_event(openai_event["text"], role)] - - # User transcription events - combine multiple similar events - elif event_type in [ - "conversation.item.input_audio_transcription.delta", - "conversation.item.input_audio_transcription.completed", - ]: - text_key = "delta" if "delta" in event_type else "transcript" - text = openai_event.get(text_key, "") - role = openai_event.get("role", "user") - is_final = "completed" in event_type - return ( - [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] - if text.strip() - 
else None - ) - - elif event_type == "conversation.item.input_audio_transcription.segment": - segment_data = openai_event.get("segment", {}) - text = segment_data.get("text", "") - role = segment_data.get("role", "user") - return ( - [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] - if text.strip() - else None - ) - - elif event_type == "conversation.item.input_audio_transcription.failed": - error_info = openai_event.get("error", {}) - logger.warning("error=<%s> | openai transcription failed", error_info.get("message", "unknown error")) - return None - - # Function call processing - elif event_type == "response.function_call_arguments.delta": - call_id = openai_event.get("call_id") - delta = openai_event.get("delta", "") - if call_id: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} - else: - self._function_call_buffer[call_id]["arguments"] += delta - return None - - elif event_type == "response.function_call_arguments.done": - call_id = openai_event.get("call_id") - if call_id and call_id in self._function_call_buffer: - function_call = self._function_call_buffer[call_id] - try: - tool_use: ToolUse = { - "toolUseId": call_id, - "name": function_call["name"], - "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, - } - del self._function_call_buffer[call_id] - # Return ToolUseStreamEvent for consistency with standard agent - return [ToolUseStreamEvent(delta={"toolUse": tool_use}, current_tool_use=dict(tool_use))] - except (json.JSONDecodeError, KeyError) as e: - logger.warning("call_id=<%s>, error=<%s> | error parsing function arguments", call_id, e) - del self._function_call_buffer[call_id] - return None - - # Voice activity detection - speech_started triggers interruption - elif event_type == "input_audio_buffer.speech_started": - # This is the primary interruption signal - handle it first - return 
[BidiInterruptionEvent(reason="user_speech")] - - # Response cancelled - handle interruption - elif event_type == "response.cancelled": - response = openai_event.get("response", {}) - response_id = response.get("id", "unknown") - logger.debug("response_id=<%s> | openai response cancelled", response_id) - return [BidiResponseCompleteEvent(response_id=response_id, stop_reason="interrupted")] - - # Turn complete and usage - response finished - elif event_type == "response.done": - response = openai_event.get("response", {}) - response_id = response.get("id", "unknown") - status = response.get("status", "completed") - usage = response.get("usage") - - # Map OpenAI status to our stop_reason - stop_reason_map = { - "completed": "complete", - "cancelled": "interrupted", - "failed": "error", - "incomplete": "interrupted", - } - - # Build list of events to return - events: list[Any] = [] - - # Always add response complete event - events.append( - BidiResponseCompleteEvent( - response_id=response_id, - stop_reason=cast(StopReason, stop_reason_map.get(status, "complete")), - ), - ) - - # Add usage event if available - if usage: - input_details = usage.get("input_token_details", {}) - output_details = usage.get("output_token_details", {}) - - # Build modality details - modality_details = [] - - # Text modality - text_input = input_details.get("text_tokens", 0) - text_output = output_details.get("text_tokens", 0) - if text_input > 0 or text_output > 0: - modality_details.append( - {"modality": "text", "input_tokens": text_input, "output_tokens": text_output} - ) - - # Audio modality - audio_input = input_details.get("audio_tokens", 0) - audio_output = output_details.get("audio_tokens", 0) - if audio_input > 0 or audio_output > 0: - modality_details.append( - {"modality": "audio", "input_tokens": audio_input, "output_tokens": audio_output} - ) - - # Image modality - image_input = input_details.get("image_tokens", 0) - if image_input > 0: - modality_details.append({"modality": 
"image", "input_tokens": image_input, "output_tokens": 0}) - - # Cached tokens - cached_tokens = input_details.get("cached_tokens", 0) - - # Add usage event - events.append( - BidiUsageEvent( - input_tokens=usage.get("input_tokens", 0), - output_tokens=usage.get("output_tokens", 0), - total_tokens=usage.get("total_tokens", 0), - modality_details=cast(list[ModalityUsage], modality_details) if modality_details else None, - cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None, - ) - ) - - # Return list of events - return events - - # Lifecycle events (log only) - combine multiple similar events - elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: - item = openai_event.get("item", {}) - action = "retrieved" if "retrieve" in event_type else "added" - logger.debug("action=<%s>, item_id=<%s> | openai conversation item event", action, item.get("id")) - return None - - elif event_type == "conversation.item.done": - logger.debug("item_id=<%s> | openai conversation item done", openai_event.get("item", {}).get("id")) - return None - - # Response output events - combine similar events - elif event_type in [ - "response.output_item.added", - "response.output_item.done", - "response.content_part.added", - "response.content_part.done", - ]: - item_data = openai_event.get("item") or openai_event.get("part") - logger.debug( - "event_type=<%s>, item_id=<%s> | openai output event", - event_type, - item_data.get("id") if item_data else "unknown", - ) - - # Track function call names from response.output_item.added - if event_type == "response.output_item.added": - item = openai_event.get("item", {}) - if item.get("type") == "function_call": - call_id = item.get("call_id") - function_name = item.get("name") - if call_id and function_name: - if call_id not in self._function_call_buffer: - self._function_call_buffer[call_id] = { - "call_id": call_id, - "name": function_name, - "arguments": "", - } - else: - 
self._function_call_buffer[call_id]["name"] = function_name - return None - - # Session/buffer events - combine simple log-only events - elif event_type in [ - "input_audio_buffer.committed", - "input_audio_buffer.cleared", - "session.created", - "session.updated", - ]: - logger.debug("event_type=<%s> | openai event received", event_type) - return None - - elif event_type == "error": - error_data = openai_event.get("error", {}) - error_code = error_data.get("code", "") - - # Suppress expected errors that don't affect session state - if error_code == "response_cancel_not_active": - # This happens when trying to cancel a response that's not active - # It's safe to ignore as the session remains functional - logger.debug("openai response cancel attempted when no response active") - return None - - # Log other errors - logger.error("error=<%s> | openai realtime error", error_data) - return None - - else: - logger.debug("event_type=<%s> | unhandled openai event type", event_type) - return None - - async def send( - self, - content: BidiInputEvent | ToolResultEvent, - ) -> None: - """Unified send method for all content types. Sends the given content to OpenAI. - - Dispatches to appropriate internal handler based on content type. - - Args: - content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). - - Raises: - ValueError: If content type not supported (e.g., image content). 
- """ - if not self._connection_id: - raise RuntimeError("model not started | call start before sending") - - # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first - if isinstance(content, BidiTextInputEvent): - await self._send_text_content(content.text) - elif isinstance(content, BidiAudioInputEvent): - await self._send_audio_content(content) - elif isinstance(content, ToolResultEvent): - tool_result = content.get("tool_result") - if tool_result: - await self._send_tool_result(tool_result) - else: - raise ValueError(f"content_type={type(content)} | content not supported") - - async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: - """Internal: Send audio content to OpenAI for processing.""" - # Audio is already base64 encoded in the event - await self._send_event({"type": "input_audio_buffer.append", "audio": audio_input.audio}) - - async def _send_text_content(self, text: str) -> None: - """Internal: Send text content to OpenAI for processing.""" - item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) - - async def _send_interrupt(self) -> None: - """Internal: Send interruption signal to OpenAI.""" - await self._send_event({"type": "response.cancel"}) - - async def _send_tool_result(self, tool_result: ToolResult) -> None: - """Internal: Send tool result back to OpenAI.""" - tool_use_id = tool_result.get("toolUseId") - - logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) - - # Validate content types and serialize, preserving structure - result_output = "" - if "content" in tool_result: - # First validate all content types are supported - for block in tool_result["content"]: - if "text" not in block and "json" not in block: - # Unsupported content type - raise error - raise ValueError( - 
f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " - f"Content type not supported by OpenAI Realtime API" - ) - - # Preserve structure by JSON-dumping the entire content array - result_output = json.dumps(tool_result["content"]) - - item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_output} - await self._send_event({"type": "conversation.item.create", "item": item_data}) - await self._send_event({"type": "response.create"}) - - async def stop(self) -> None: - """Close session and cleanup resources.""" - logger.debug("openai realtime connection cleanup starting") - - async def stop_websocket() -> None: - if not hasattr(self, "_websocket"): - return - - await self._websocket.close() - - async def stop_connection() -> None: - self._connection_id = None - - await stop_all(stop_websocket, stop_connection) - - logger.debug("openai realtime connection closed") - - async def _send_event(self, event: dict[str, Any]) -> None: - """Send event to OpenAI via WebSocket.""" - message = json.dumps(event) - await self._websocket.send(message) - logger.debug("event_type=<%s> | openai event sent", event.get("type")) diff --git a/src/strands/experimental/bidi/types/bidi_model.py b/src/strands/experimental/bidi/types/bidi_model.py deleted file mode 100644 index de41de1a9..000000000 --- a/src/strands/experimental/bidi/types/bidi_model.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Model-related type definitions for bidirectional streaming. - -Defines types and configurations that are central to model providers, -including audio configuration that models use to specify their audio -processing requirements. -""" - -from typing import TypedDict - -from .events import AudioChannel, AudioFormat, AudioSampleRate - - -class AudioConfig(TypedDict, total=False): - """Audio configuration for bidirectional streaming models. - - Defines standard audio parameters that model providers use to specify - their audio processing requirements. 
All fields are optional to support - models that may not use audio or only need specific parameters. - - Model providers build this configuration by merging user-provided values - with their own defaults. The resulting configuration is then used by - audio I/O implementations to configure hardware appropriately. - - Attributes: - input_rate: Input sample rate in Hz (e.g., 16000, 24000, 48000) - output_rate: Output sample rate in Hz (e.g., 16000, 24000, 48000) - channels: Number of audio channels (1=mono, 2=stereo) - format: Audio encoding format - voice: Voice identifier for text-to-speech (e.g., "alloy", "matthew") - """ - - input_rate: AudioSampleRate - output_rate: AudioSampleRate - channels: AudioChannel - format: AudioFormat - voice: str diff --git a/src/strands/experimental/bidi/types/events.py b/src/strands/experimental/bidi/types/events.py index 7ea2b6345..d9905c16b 100644 --- a/src/strands/experimental/bidi/types/events.py +++ b/src/strands/experimental/bidi/types/events.py @@ -25,7 +25,7 @@ from ....types.streaming import ContentBlockDelta if TYPE_CHECKING: - from ..models.bidi_model import BidiModelTimeoutError + from ..models.model import BidiModelTimeoutError AudioChannel = Literal[1, 2] """Number of audio channels. 
diff --git a/tests/strands/experimental/bidi/models/test_gemini_live.py b/tests/strands/experimental/bidi/models/test_gemini_live.py index a880bb223..c92211816 100644 --- a/tests/strands/experimental/bidi/models/test_gemini_live.py +++ b/tests/strands/experimental/bidi/models/test_gemini_live.py @@ -13,7 +13,7 @@ import pytest from google.genai import types as genai_types -from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.models.gemini_live import BidiGeminiLiveModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_novasonic.py index 39524e434..7ec0c32a1 100644 --- a/tests/strands/experimental/bidi/models/test_novasonic.py +++ b/tests/strands/experimental/bidi/models/test_novasonic.py @@ -13,10 +13,10 @@ import pytest_asyncio from aws_sdk_bedrock_runtime.models import ModelTimeoutException, ValidationException -from strands.experimental.bidi.models.novasonic import ( +from strands.experimental.bidi.models.nova_sonic import ( BidiNovaSonicModel, ) -from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError +from strands.experimental.bidi.models.model import BidiModelTimeoutError from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, diff --git a/tests/strands/experimental/bidi/models/test_openai.py b/tests/strands/experimental/bidi/models/test_openai.py index 85a1cc097..5b3d627fd 100644 --- a/tests/strands/experimental/bidi/models/test_openai.py +++ b/tests/strands/experimental/bidi/models/test_openai.py @@ -14,8 +14,8 @@ import pytest -from strands.experimental.bidi.models.bidi_model import BidiModelTimeoutError -from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel +from strands.experimental.bidi.models.model 
import BidiModelTimeoutError +from strands.experimental.bidi.models.openai_realtime import BidiOpenAIRealtimeModel from strands.experimental.bidi.types.events import ( BidiAudioInputEvent, BidiAudioStreamEvent, From 0dd05fe99ecdfb9d8243686adefed6ce0b881780 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 16:09:27 -0500 Subject: [PATCH 2/6] rename files --- src/strands/experimental/bidi/models/model.py | 130 +++ .../experimental/bidi/models/nova_sonic.py | 760 ++++++++++++++++ .../bidi/models/openai_realtime.py | 816 ++++++++++++++++++ src/strands/experimental/bidi/types/model.py | 36 + 4 files changed, 1742 insertions(+) create mode 100644 src/strands/experimental/bidi/models/model.py create mode 100644 src/strands/experimental/bidi/models/nova_sonic.py create mode 100644 src/strands/experimental/bidi/models/openai_realtime.py create mode 100644 src/strands/experimental/bidi/types/model.py diff --git a/src/strands/experimental/bidi/models/model.py b/src/strands/experimental/bidi/models/model.py new file mode 100644 index 000000000..0d0da63d2 --- /dev/null +++ b/src/strands/experimental/bidi/models/model.py @@ -0,0 +1,130 @@ +"""Bidirectional streaming model interface. + +Defines the abstract interface for models that support real-time bidirectional +communication with persistent connections. Unlike traditional request-response +models, bidirectional models maintain an open connection for streaming audio, +text, and tool interactions. 
+ +Features: +- Persistent connection management with connect/close lifecycle +- Real-time bidirectional communication (send and receive simultaneously) +- Provider-agnostic event normalization +- Support for audio, text, image, and tool result streaming +""" + +import logging +from typing import Any, AsyncIterable, Protocol + +from ....types._events import ToolResultEvent +from ....types.content import Messages +from ....types.tools import ToolSpec +from ..types.events import ( + BidiInputEvent, + BidiOutputEvent, +) + +logger = logging.getLogger(__name__) + + +class BidiModel(Protocol): + """Protocol for bidirectional streaming models. + + This interface defines the contract for models that support persistent streaming + connections with real-time audio and text communication. Implementations handle + provider-specific protocols while exposing a standardized event-based API. + + Attributes: + config: Configuration dictionary with provider-specific settings. + """ + + config: dict[str, Any] + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish a persistent streaming connection with the model. + + Opens a bidirectional connection that remains active for real-time communication. + The connection supports concurrent sending and receiving of events until explicitly + closed. Must be called before any send() or receive() operations. + + Args: + system_prompt: System instructions to configure model behavior. + tools: Tool specifications that the model can invoke during the conversation. + messages: Initial conversation history to provide context. + **kwargs: Provider-specific configuration options. + """ + ... + + async def stop(self) -> None: + """Close the streaming connection and release resources. 
+ + Terminates the active bidirectional connection and cleans up any associated + resources such as network connections, buffers, or background tasks. After + calling close(), the model instance cannot be used until start() is called again. + """ + ... + + def receive(self) -> AsyncIterable[BidiOutputEvent]: + """Receive streaming events from the model. + + Continuously yields events from the model as they arrive over the connection. + Events are normalized to a provider-agnostic format for uniform processing. + This method should be called in a loop or async task to process model responses. + + The stream continues until the connection is closed or an error occurs. + + Yields: + BidiOutputEvent: Standardized event objects containing audio output, + transcripts, tool calls, or control signals. + """ + ... + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Send content to the model over the active connection. + + Transmits user input or tool results to the model during an active streaming + session. Supports multiple content types including text, audio, images, and + tool execution results. Can be called multiple times during a conversation. + + Args: + content: The content to send. Must be one of: + - BidiTextInputEvent: Text message from the user + - BidiAudioInputEvent: Audio data for speech input + - BidiImageInputEvent: Image data for visual understanding + - ToolResultEvent: Result from a tool execution + + Example: + await model.send(BidiTextInputEvent(text="Hello", role="user")) + await model.send(BidiAudioInputEvent(audio=bytes, format="pcm", sample_rate=16000, channels=1)) + await model.send(BidiImageInputEvent(image=bytes, mime_type="image/jpeg", encoding="raw")) + await model.send(ToolResultEvent(tool_result)) + """ + ... + + +class BidiModelTimeoutError(Exception): + """Model timeout error. + + Bidirectional models are often configured with a connection time limit. 
Nova sonic for example keeps the connection + open for 8 minutes max. Upon receiving a timeout, the agent loop is configured to restart the model connection so as + to create a seamless, uninterrupted experience for the user. + """ + + def __init__(self, message: str, **restart_config: Any) -> None: + """Initialize error. + + Args: + message: Timeout message from model. + **restart_config: Configure restart specific behaviors in the call to model start. + """ + super().__init__(message) + + self.restart_config = restart_config diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py new file mode 100644 index 000000000..262b37240 --- /dev/null +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -0,0 +1,760 @@ +"""Nova Sonic bidirectional model provider for real-time streaming conversations. + +Implements the BidiModel interface for Amazon's Nova Sonic, handling the +complex event sequencing and audio processing required by Nova Sonic's +InvokeModelWithBidirectionalStream protocol.
+ +Nova Sonic specifics: +- Hierarchical event sequences: connectionStart → promptStart → content streaming +- Base64-encoded audio format with hex encoding +- Tool execution with content containers and identifier tracking +- 8-minute connection limits with proper cleanup sequences +- Interruption detection through stopReason events +""" + +import asyncio +import base64 +import json +import logging +import uuid +from typing import Any, AsyncGenerator, cast + +import boto3 +from aws_sdk_bedrock_runtime.client import BedrockRuntimeClient, InvokeModelWithBidirectionalStreamOperationInput +from aws_sdk_bedrock_runtime.config import Config, HTTPAuthSchemeResolver, SigV4AuthScheme +from aws_sdk_bedrock_runtime.models import ( + BidirectionalInputPayloadPart, + InvokeModelWithBidirectionalStreamInputChunk, + ModelTimeoutException, + ValidationException, +) +from smithy_aws_core.identity.static import StaticCredentialsResolver +from smithy_core.aio.eventstream import DuplexEventStream +from smithy_core.shapes import ShapeID + +from ....types._events import ToolResultEvent, ToolUseStreamEvent +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec, ToolUse +from .._async import stop_all +from ..types.bidi_model import AudioConfig +from ..types.events import ( + AudioChannel, + AudioSampleRate, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionStartEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, +) +from .model import BidiModel, BidiModelTimeoutError + +logger = logging.getLogger(__name__) + +# Nova Sonic configuration constants +NOVA_INFERENCE_CONFIG = {"maxTokens": 1024, "topP": 0.9, "temperature": 0.7} + +NOVA_AUDIO_INPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "audioType": "SPEECH", + "encoding": "base64", +} 
+ +NOVA_AUDIO_OUTPUT_CONFIG = { + "mediaType": "audio/lpcm", + "sampleRateHertz": 16000, + "sampleSizeBits": 16, + "channelCount": 1, + "voiceId": "matthew", + "encoding": "base64", + "audioType": "SPEECH", +} + +NOVA_TEXT_CONFIG = {"mediaType": "text/plain"} +NOVA_TOOL_CONFIG = {"mediaType": "application/json"} + + +class BidiNovaSonicModel(BidiModel): + """Nova Sonic implementation for bidirectional streaming. + + Combines model configuration and connection state in a single class. + Manages Nova Sonic's complex event sequencing, audio format conversion, and + tool execution patterns while providing the standard BidiModel interface. + + Attributes: + _stream: open bedrock stream to nova sonic. + """ + + _stream: DuplexEventStream + + def __init__( + self, + model_id: str = "amazon.nova-sonic-v1:0", + provider_config: dict[str, Any] | None = None, + client_config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Initialize Nova Sonic bidirectional model. + + Args: + model_id: Model identifier (default: amazon.nova-sonic-v1:0) + provider_config: Model behavior (audio, inference settings) + client_config: AWS authentication (boto_session OR region, not both) + **kwargs: Reserved for future parameters. 
+ """ + # Store model ID + self.model_id = model_id + + # Resolve client config with defaults + self._client_config = self._resolve_client_config(client_config or {}) + + # Resolve provider config with defaults + self.config = self._resolve_provider_config(provider_config or {}) + + # Store session and region for later use + self._session = self._client_config["boto_session"] + self.region = self._client_config["region"] + + # Track API-provided identifiers + self._connection_id: str | None = None + self._audio_content_name: str | None = None + self._current_completion_id: str | None = None + + # Indicates if model is done generating transcript + self._generation_stage: str | None = None + + # Ensure certain events are sent in sequence when required + self._send_lock = asyncio.Lock() + + logger.debug("model_id=<%s> | nova sonic model initialized", model_id) + + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve AWS client config (creates boto session if needed).""" + if "boto_session" in config and "region" in config: + raise ValueError("Cannot specify both 'boto_session' and 'region' in client_config") + + resolved = config.copy() + + # Create boto session if not provided + if "boto_session" not in resolved: + resolved["boto_session"] = boto3.Session() + + # Resolve region from session or use default + if "region" not in resolved: + resolved["region"] = resolved["boto_session"].region_name or "us-east-1" + + return resolved + + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" + # Define default audio configuration + default_audio_config: AudioConfig = { + "input_rate": cast(AudioSampleRate, NOVA_AUDIO_INPUT_CONFIG["sampleRateHertz"]), + "output_rate": cast(AudioSampleRate, NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"]), + "channels": cast(AudioChannel, NOVA_AUDIO_INPUT_CONFIG["channelCount"]), + "format": "pcm", + "voice": cast(str, 
NOVA_AUDIO_OUTPUT_CONFIG["voiceId"]), + } + + user_audio_config = config.get("audio", {}) + merged_audio = {**default_audio_config, **user_audio_config} + + resolved = { + "audio": merged_audio, + **{k: v for k, v in config.items() if k != "audio"}, + } + + if user_audio_config: + logger.debug("audio_config | merged user-provided config with defaults") + else: + logger.debug("audio_config | using default Nova Sonic audio configuration") + + return resolved + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish bidirectional connection to Nova Sonic. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. + + Raises: + RuntimeError: If user calls start again without first stopping. + """ + if self._connection_id: + raise RuntimeError("model already started | call stop before starting again") + + logger.debug("nova connection starting") + + self._connection_id = str(uuid.uuid4()) + + # Get credentials from boto3 session (full credential chain) + credentials = self._session.get_credentials() + + if not credentials: + raise ValueError( + "no AWS credentials found. configure credentials via environment variables, " + "credential files, IAM roles, or SSO." 
+ ) + + # Use static resolver with credentials configured as properties + resolver = StaticCredentialsResolver() + + config = Config( + endpoint_uri=f"https://bedrock-runtime.{self.region}.amazonaws.com", + region=self.region, + aws_credentials_identity_resolver=resolver, + auth_scheme_resolver=HTTPAuthSchemeResolver(), + auth_schemes={ShapeID("aws.auth#sigv4"): SigV4AuthScheme(service="bedrock")}, + # Configure static credentials as properties + aws_access_key_id=credentials.access_key, + aws_secret_access_key=credentials.secret_key, + aws_session_token=credentials.token, + ) + + self.client = BedrockRuntimeClient(config=config) + logger.debug("region=<%s> | nova sonic client initialized", self.region) + + client = BedrockRuntimeClient(config=config) + self._stream = await client.invoke_model_with_bidirectional_stream( + InvokeModelWithBidirectionalStreamOperationInput(model_id=self.model_id) + ) + logger.debug("region=<%s> | nova sonic client initialized", self.region) + + init_events = self._build_initialization_events(system_prompt, tools, messages) + logger.debug("event_count=<%d> | sending nova sonic initialization events", len(init_events)) + await self._send_nova_events(init_events) + + logger.info("connection_id=<%s> | nova sonic connection established", self._connection_id) + + def _build_initialization_events( + self, system_prompt: str | None, tools: list[ToolSpec] | None, messages: Messages | None + ) -> list[str]: + """Build the sequence of initialization events.""" + tools = tools or [] + events = [ + self._get_connection_start_event(), + self._get_prompt_start_event(tools), + *self._get_system_prompt_events(system_prompt), + ] + + # Add conversation history if provided + if messages: + events.extend(self._get_message_history_events(messages)) + logger.debug("message_count=<%d> | conversation history added to initialization", len(messages)) + + return events + + def _log_event_type(self, nova_event: dict[str, Any]) -> None: + """Log specific Nova 
Sonic event types for debugging.""" + if "usageEvent" in nova_event: + logger.debug("usage=<%s> | nova usage event received", nova_event["usageEvent"]) + elif "textOutput" in nova_event: + logger.debug("nova text output received") + elif "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + logger.debug( + "tool_name=<%s>, tool_use_id=<%s> | nova tool use received", + tool_use["toolName"], + tool_use["toolUseId"], + ) + elif "audioOutput" in nova_event: + audio_content = nova_event["audioOutput"]["content"] + audio_bytes = base64.b64decode(audio_content) + logger.debug("audio_bytes=<%d> | nova audio output received", len(audio_bytes)) + + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: + """Receive Nova Sonic events and convert to provider-agnostic format. + + Raises: + RuntimeError: If start has not been called. + """ + if not self._connection_id: + raise RuntimeError("model not started | call start before receiving") + + logger.debug("nova event stream starting") + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + _, output = await self._stream.await_output() + while True: + try: + event_data = await output.receive() + + except ValidationException as error: + if "InternalErrorCode=531" in error.message: + # nova also times out if user is silent for 175 seconds + raise BidiModelTimeoutError(error.message) from error + raise + + except ModelTimeoutException as error: + raise BidiModelTimeoutError(error.message) from error + + if not event_data: + continue + + nova_event = json.loads(event_data.value.bytes_.decode("utf-8"))["event"] + self._log_event_type(nova_event) + + model_event = self._convert_nova_event(nova_event) + if model_event: + yield model_event + + async def send(self, content: BidiInputEvent | ToolResultEvent) -> None: + """Unified send method for all content types. Sends the given content to Nova Sonic. + + Dispatches to appropriate internal handler based on content type. 
+ + Args: + content: Input event. + + Raises: + ValueError: If content type not supported (e.g., image content). + """ + if not self._connection_id: + raise RuntimeError("model not started | call start before sending") + + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + raise ValueError(f"content_type={type(content)} | content not supported") + + async def _start_audio_connection(self) -> None: + """Internal: Start audio input connection (call once before sending audio chunks).""" + logger.debug("nova audio connection starting") + self._audio_content_name = str(uuid.uuid4()) + + # Build audio input configuration from config + audio_input_config = { + "mediaType": "audio/lpcm", + "sampleRateHertz": self.config["audio"]["input_rate"], + "sampleSizeBits": 16, + "channelCount": self.config["audio"]["channels"], + "audioType": "SPEECH", + "encoding": "base64", + } + + audio_content_start = json.dumps( + { + "event": { + "contentStart": { + "promptName": self._connection_id, + "contentName": self._audio_content_name, + "type": "AUDIO", + "interactive": True, + "role": "USER", + "audioInputConfiguration": audio_input_config, + } + } + } + ) + + await self._send_nova_events([audio_content_start]) + + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: + """Internal: Send audio using Nova Sonic protocol-specific format.""" + # Start audio connection if not already active + if not self._audio_content_name: + await self._start_audio_connection() + + # Audio is already base64 encoded in the event + # Send audio input event + audio_event = json.dumps( + { + "event": { + "audioInput": { + "promptName": self._connection_id, + "contentName": self._audio_content_name, 
+ "content": audio_input.audio, + } + } + } + ) + + await self._send_nova_events([audio_event]) + + async def _end_audio_input(self) -> None: + """Internal: End current audio input connection to trigger Nova Sonic processing.""" + if not self._audio_content_name: + return + + logger.debug("nova audio connection ending") + + audio_content_end = json.dumps( + {"event": {"contentEnd": {"promptName": self._connection_id, "contentName": self._audio_content_name}}} + ) + + await self._send_nova_events([audio_content_end]) + self._audio_content_name = None + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content using Nova Sonic format.""" + content_name = str(uuid.uuid4()) + events = [ + self._get_text_content_start_event(content_name), + self._get_text_input_event(content_name, text), + self._get_content_end_event(content_name), + ] + await self._send_nova_events(events) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result using Nova Sonic toolResult format.""" + tool_use_id = tool_result["toolUseId"] + + logger.debug("tool_use_id=<%s> | sending nova tool result", tool_use_id) + + # Validate content types and preserve structure + content = tool_result.get("content", []) + + # Validate all content types are supported + for block in content: + if "text" not in block and "json" not in block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by Nova Sonic" + ) + + # Optimize for single content item - unwrap the array + if len(content) == 1: + result_data = cast(dict[str, Any], content[0]) + else: + # Multiple items - send as array + result_data = {"content": content} + + content_name = str(uuid.uuid4()) + events = [ + self._get_tool_content_start_event(content_name, tool_use_id), + self._get_tool_result_event(content_name, result_data), + 
self._get_content_end_event(content_name), + ] + await self._send_nova_events(events) + + async def stop(self) -> None: + """Close Nova Sonic connection with proper cleanup sequence.""" + logger.debug("nova connection cleanup starting") + + async def stop_events() -> None: + if not self._connection_id: + return + + await self._end_audio_input() + cleanup_events = [self._get_prompt_end_event(), self._get_connection_end_event()] + await self._send_nova_events(cleanup_events) + + async def stop_stream() -> None: + if not hasattr(self, "_stream"): + return + + await self._stream.close() + + async def stop_connection() -> None: + self._connection_id = None + + await stop_all(stop_events, stop_stream, stop_connection) + + logger.debug("nova connection closed") + + def _convert_nova_event(self, nova_event: dict[str, Any]) -> BidiOutputEvent | None: + """Convert Nova Sonic events to TypedEvent format.""" + # Handle completion start - track completionId + if "completionStart" in nova_event: + completion_data = nova_event["completionStart"] + self._current_completion_id = completion_data.get("completionId") + logger.debug("completion_id=<%s> | nova completion started", self._current_completion_id) + return None + + # Handle completion end + if "completionEnd" in nova_event: + completion_data = nova_event["completionEnd"] + completion_id = completion_data.get("completionId", self._current_completion_id) + stop_reason = completion_data.get("stopReason", "END_TURN") + + event = BidiResponseCompleteEvent( + response_id=completion_id or str(uuid.uuid4()), # Fallback to UUID if missing + stop_reason="interrupted" if stop_reason == "INTERRUPTED" else "complete", + ) + + # Clear completion tracking + self._current_completion_id = None + return event + + # Handle audio output + if "audioOutput" in nova_event: + # Audio is already base64 string from Nova Sonic + audio_content = nova_event["audioOutput"]["content"] + return BidiAudioStreamEvent( + audio=audio_content, + format="pcm", + 
sample_rate=cast(AudioSampleRate, self.config["audio"]["output_rate"]), + channels=cast(AudioChannel, self.config["audio"]["channels"]), + ) + + # Handle text output (transcripts) + elif "textOutput" in nova_event: + text_output = nova_event["textOutput"] + text_content = text_output["content"] + # Check for Nova Sonic interruption pattern + if '{ "interrupted" : true }' in text_content: + logger.debug("nova interruption detected in text output") + return BidiInterruptionEvent(reason="user_speech") + + return BidiTranscriptStreamEvent( + delta={"text": text_content}, + text=text_content, + role=text_output["role"].lower(), + is_final=self._generation_stage == "FINAL", + current_transcript=text_content, + ) + + # Handle tool use + if "toolUse" in nova_event: + tool_use = nova_event["toolUse"] + tool_use_event: ToolUse = { + "toolUseId": tool_use["toolUseId"], + "name": tool_use["toolName"], + "input": json.loads(tool_use["content"]), + } + # Return ToolUseStreamEvent - cast to dict for type compatibility + return ToolUseStreamEvent(delta={"toolUse": tool_use_event}, current_tool_use=dict(tool_use_event)) + + # Handle interruption + if nova_event.get("stopReason") == "INTERRUPTED": + logger.debug("nova interruption detected via stop reason") + return BidiInterruptionEvent(reason="user_speech") + + # Handle usage events - convert to multimodal usage format + if "usageEvent" in nova_event: + usage_data = nova_event["usageEvent"] + total_input = usage_data.get("totalInputTokens", 0) + total_output = usage_data.get("totalOutputTokens", 0) + + return BidiUsageEvent( + input_tokens=total_input, + output_tokens=total_output, + total_tokens=usage_data.get("totalTokens", total_input + total_output), + ) + + # Handle content start events (emit response start) + if "contentStart" in nova_event: + content_data = nova_event["contentStart"] + if content_data["type"] == "TEXT": + self._generation_stage = json.loads(content_data["additionalModelFields"])["generationStage"] + + # 
Emit response start event using API-provided completionId + # completionId should already be tracked from completionStart event + return BidiResponseStartEvent( + response_id=self._current_completion_id or str(uuid.uuid4()) # Fallback to UUID if missing + ) + + if "contentEnd" in nova_event: + self._generation_stage = None + + # Ignore all other events + return None + + def _get_connection_start_event(self) -> str: + """Generate Nova Sonic connection start event.""" + return json.dumps({"event": {"sessionStart": {"inferenceConfiguration": NOVA_INFERENCE_CONFIG}}}) + + def _get_prompt_start_event(self, tools: list[ToolSpec]) -> str: + """Generate Nova Sonic prompt start event with tool configuration.""" + # Build audio output configuration from config + audio_output_config = { + "mediaType": "audio/lpcm", + "sampleRateHertz": self.config["audio"]["output_rate"], + "sampleSizeBits": 16, + "channelCount": self.config["audio"]["channels"], + "voiceId": self.config["audio"].get("voice", "matthew"), + "encoding": "base64", + "audioType": "SPEECH", + } + + prompt_start_event: dict[str, Any] = { + "event": { + "promptStart": { + "promptName": self._connection_id, + "textOutputConfiguration": NOVA_TEXT_CONFIG, + "audioOutputConfiguration": audio_output_config, + } + } + } + + if tools: + tool_config = self._build_tool_configuration(tools) + prompt_start_event["event"]["promptStart"]["toolUseOutputConfiguration"] = NOVA_TOOL_CONFIG + prompt_start_event["event"]["promptStart"]["toolConfiguration"] = {"tools": tool_config} + + return json.dumps(prompt_start_event) + + def _build_tool_configuration(self, tools: list[ToolSpec]) -> list[dict[str, Any]]: + """Build tool configuration from tool specs.""" + tool_config: list[dict[str, Any]] = [] + for tool in tools: + input_schema = ( + {"json": json.dumps(tool["inputSchema"]["json"])} + if "json" in tool["inputSchema"] + else {"json": json.dumps(tool["inputSchema"])} + ) + + tool_config.append( + {"toolSpec": {"name": tool["name"], 
"description": tool["description"], "inputSchema": input_schema}} + ) + return tool_config + + def _get_system_prompt_events(self, system_prompt: str | None) -> list[str]: + """Generate system prompt events.""" + content_name = str(uuid.uuid4()) + return [ + self._get_text_content_start_event(content_name, "SYSTEM"), + self._get_text_input_event(content_name, system_prompt or ""), + self._get_content_end_event(content_name), + ] + + def _get_message_history_events(self, messages: Messages) -> list[str]: + """Generate conversation history events from agent messages. + + Converts agent message history to Nova Sonic format following the + contentStart/textInput/contentEnd pattern for each message. + + Args: + messages: List of conversation messages with role and content. + + Returns: + List of JSON event strings for Nova Sonic. + """ + events = [] + + for message in messages: + role = message["role"].upper() # Convert to ASSISTANT or USER + content_blocks = message.get("content", []) + + # Extract text content from content blocks + text_parts = [] + for block in content_blocks: + if "text" in block: + text_parts.append(block["text"]) + + # Combine all text parts + if text_parts: + combined_text = "\n".join(text_parts) + content_name = str(uuid.uuid4()) + + # Add contentStart, textInput, and contentEnd events + events.extend( + [ + self._get_text_content_start_event(content_name, role), + self._get_text_input_event(content_name, combined_text), + self._get_content_end_event(content_name), + ] + ) + + return events + + def _get_text_content_start_event(self, content_name: str, role: str = "USER") -> str: + """Generate text content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self._connection_id, + "contentName": content_name, + "type": "TEXT", + "role": role, + "interactive": True, + "textInputConfiguration": NOVA_TEXT_CONFIG, + } + } + } + ) + + def _get_tool_content_start_event(self, content_name: str, tool_use_id: str) -> 
str: + """Generate tool content start event.""" + return json.dumps( + { + "event": { + "contentStart": { + "promptName": self._connection_id, + "contentName": content_name, + "interactive": False, + "type": "TOOL", + "role": "TOOL", + "toolResultInputConfiguration": { + "toolUseId": tool_use_id, + "type": "TEXT", + "textInputConfiguration": NOVA_TEXT_CONFIG, + }, + } + } + } + ) + + def _get_text_input_event(self, content_name: str, text: str) -> str: + """Generate text input event.""" + return json.dumps( + {"event": {"textInput": {"promptName": self._connection_id, "contentName": content_name, "content": text}}} + ) + + def _get_tool_result_event(self, content_name: str, result: dict[str, Any]) -> str: + """Generate tool result event.""" + return json.dumps( + { + "event": { + "toolResult": { + "promptName": self._connection_id, + "contentName": content_name, + "content": json.dumps(result), + } + } + } + ) + + def _get_content_end_event(self, content_name: str) -> str: + """Generate content end event.""" + return json.dumps({"event": {"contentEnd": {"promptName": self._connection_id, "contentName": content_name}}}) + + def _get_prompt_end_event(self) -> str: + """Generate prompt end event.""" + return json.dumps({"event": {"promptEnd": {"promptName": self._connection_id}}}) + + def _get_connection_end_event(self) -> str: + """Generate connection end event.""" + return json.dumps({"event": {"connectionEnd": {}}}) + + async def _send_nova_events(self, events: list[str]) -> None: + """Send event JSON string to Nova Sonic stream. + + A lock is used to send events in sequence when required (e.g., tool result start, content, and end). + + Args: + events: Jsonified events. 
+ """ + async with self._send_lock: + for event in events: + bytes_data = event.encode("utf-8") + chunk = InvokeModelWithBidirectionalStreamInputChunk( + value=BidirectionalInputPayloadPart(bytes_=bytes_data) + ) + await self._stream.input_stream.send(chunk) + logger.debug("nova sonic event sent successfully") diff --git a/src/strands/experimental/bidi/models/openai_realtime.py b/src/strands/experimental/bidi/models/openai_realtime.py new file mode 100644 index 000000000..79ef5f78c --- /dev/null +++ b/src/strands/experimental/bidi/models/openai_realtime.py @@ -0,0 +1,816 @@ +"""OpenAI Realtime API provider for Strands bidirectional streaming. + +Provides real-time audio and text communication through OpenAI's Realtime API +with WebSocket connections, voice activity detection, and function calling. +""" + +import asyncio +import json +import logging +import os +import time +import uuid +from typing import Any, AsyncGenerator, Literal, cast + +import websockets +from websockets import ClientConnection + +from ....types._events import ToolResultEvent, ToolUseStreamEvent +from ....types.content import Messages +from ....types.tools import ToolResult, ToolSpec, ToolUse +from .._async import stop_all +from ..types.bidi_model import AudioConfig +from ..types.events import ( + AudioSampleRate, + BidiAudioInputEvent, + BidiAudioStreamEvent, + BidiConnectionStartEvent, + BidiInputEvent, + BidiInterruptionEvent, + BidiOutputEvent, + BidiResponseCompleteEvent, + BidiResponseStartEvent, + BidiTextInputEvent, + BidiTranscriptStreamEvent, + BidiUsageEvent, + ModalityUsage, + Role, + StopReason, +) +from .model import BidiModel, BidiModelTimeoutError + +logger = logging.getLogger(__name__) + +# Test idle_timeout_ms + +# OpenAI Realtime API configuration +OPENAI_MAX_TIMEOUT_S = 3000 # 50 minutes +"""Max timeout before closing connection. 
+ +OpenAI documents a 60 minute limit on realtime sessions +(https://platform.openai.com/docs/guides/realtime-conversations#session-lifecycle-events). However, OpenAI does not +emit any warnings when approaching the limit. As a workaround, we configure a max timeout client side to gracefully +handle the connection closure. We set the max to 50 minutes to provide enough buffer before hitting the real limit. +""" +OPENAI_REALTIME_URL = "wss://api.openai.com/v1/realtime" +DEFAULT_MODEL = "gpt-realtime" +DEFAULT_SAMPLE_RATE = 24000 + +DEFAULT_SESSION_CONFIG = { + "type": "realtime", + "instructions": "You are a helpful assistant. Please speak in English and keep your responses clear and concise.", + "output_modalities": ["audio"], + "audio": { + "input": { + "format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, + "transcription": {"model": "gpt-4o-transcribe"}, + "turn_detection": { + "type": "server_vad", + "threshold": 0.5, + "prefix_padding_ms": 300, + "silence_duration_ms": 500, + }, + }, + "output": {"format": {"type": "audio/pcm", "rate": DEFAULT_SAMPLE_RATE}, "voice": "alloy"}, + }, +} + + +class BidiOpenAIRealtimeModel(BidiModel): + """OpenAI Realtime API implementation for bidirectional streaming. + + Combines model configuration and connection state in a single class. + Manages WebSocket connection to OpenAI's Realtime API with automatic VAD, + function calling, and event conversion to Strands format. + """ + + _websocket: ClientConnection + _start_time: int + + def __init__( + self, + model_id: str = DEFAULT_MODEL, + provider_config: dict[str, Any] | None = None, + client_config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + """Initialize OpenAI Realtime bidirectional model. + + Args: + model_id: Model identifier (default: gpt-realtime) + provider_config: Model behavior (audio, instructions, turn_detection, etc.) 
+ client_config: Authentication (api_key, organization, project) + Falls back to OPENAI_API_KEY, OPENAI_ORGANIZATION, OPENAI_PROJECT env vars + **kwargs: Reserved for future parameters. + + """ + # Store model ID + self.model_id = model_id + + # Resolve client config with defaults and env vars + self._client_config = self._resolve_client_config(client_config or {}) + + # Resolve provider config with defaults + self.config = self._resolve_provider_config(provider_config or {}) + + # Store client config values for later use + self.api_key = self._client_config["api_key"] + self.organization = self._client_config.get("organization") + self.project = self._client_config.get("project") + self.timeout_s = self._client_config["timeout_s"] + + if self.timeout_s > OPENAI_MAX_TIMEOUT_S: + raise ValueError( + f"timeout_s=<{self.timeout_s}>, max_timeout_s=<{OPENAI_MAX_TIMEOUT_S}> | timeout exceeds max limit" + ) + + # Connection state (initialized in start()) + self._connection_id: str | None = None + + self._function_call_buffer: dict[str, Any] = {} + + logger.debug("model=<%s> | openai realtime model initialized", model_id) + + def _resolve_client_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Resolve client config with env var fallback (config takes precedence).""" + resolved = config.copy() + + if "api_key" not in resolved: + resolved["api_key"] = os.getenv("OPENAI_API_KEY") + + if not resolved.get("api_key"): + raise ValueError( + "OpenAI API key is required. Provide via client_config={'api_key': '...'} " + "or set OPENAI_API_KEY environment variable." 
+ ) + if "organization" not in resolved: + env_org = os.getenv("OPENAI_ORGANIZATION") + if env_org: + resolved["organization"] = env_org + + if "project" not in resolved: + env_project = os.getenv("OPENAI_PROJECT") + if env_project: + resolved["project"] = env_project + + if "timeout_s" not in resolved: + resolved["timeout_s"] = OPENAI_MAX_TIMEOUT_S + + return resolved + + def _resolve_provider_config(self, config: dict[str, Any]) -> dict[str, Any]: + """Merge user config with defaults (user takes precedence).""" + # Extract voice from provider-specific audio.output.voice if present + provider_voice = None + if "audio" in config and isinstance(config["audio"], dict): + if "output" in config["audio"] and isinstance(config["audio"]["output"], dict): + provider_voice = config["audio"]["output"].get("voice") + + # Define default audio configuration + default_audio: AudioConfig = { + "input_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), + "output_rate": cast(AudioSampleRate, DEFAULT_SAMPLE_RATE), + "channels": 1, + "format": "pcm", + "voice": provider_voice or "alloy", + } + + user_audio = config.get("audio", {}) + merged_audio = {**default_audio, **user_audio} + + resolved = { + "audio": merged_audio, + **{k: v for k, v in config.items() if k != "audio"}, + } + + if user_audio: + logger.debug("audio_config | merged user-provided config with defaults") + else: + logger.debug("audio_config | using default OpenAI Realtime audio configuration") + + return resolved + + async def start( + self, + system_prompt: str | None = None, + tools: list[ToolSpec] | None = None, + messages: Messages | None = None, + **kwargs: Any, + ) -> None: + """Establish bidirectional connection to OpenAI Realtime API. + + Args: + system_prompt: System instructions for the model. + tools: List of tools available to the model. + messages: Conversation history to initialize with. + **kwargs: Additional configuration options. 
+ """ + if self._connection_id: + raise RuntimeError("model already started | call stop before starting again") + + logger.debug("openai realtime connection starting") + + # Initialize connection state + self._connection_id = str(uuid.uuid4()) + self._start_time = int(time.time()) + + self._function_call_buffer = {} + + # Establish WebSocket connection + url = f"{OPENAI_REALTIME_URL}?model={self.model_id}" + + headers = [("Authorization", f"Bearer {self.api_key}")] + if self.organization: + headers.append(("OpenAI-Organization", self.organization)) + if self.project: + headers.append(("OpenAI-Project", self.project)) + + self._websocket = await websockets.connect(url, additional_headers=headers) + logger.debug("connection_id=<%s> | websocket connected successfully", self._connection_id) + + # Configure session + session_config = self._build_session_config(system_prompt, tools) + await self._send_event({"type": "session.update", "session": session_config}) + + # Add conversation history if provided + if messages: + await self._add_conversation_history(messages) + + def _create_text_event(self, text: str, role: str, is_final: bool = True) -> BidiTranscriptStreamEvent: + """Create standardized transcript event. 
+ + Args: + text: The transcript text + role: The role (will be normalized to lowercase) + is_final: Whether this is the final transcript + """ + # Normalize role to lowercase and ensure it's either "user" or "assistant" + normalized_role = role.lower() if isinstance(role, str) else "assistant" + if normalized_role not in ["user", "assistant"]: + normalized_role = "assistant" + + return BidiTranscriptStreamEvent( + delta={"text": text}, + text=text, + role=cast(Role, normalized_role), + is_final=is_final, + current_transcript=text if is_final else None, + ) + + def _create_voice_activity_event(self, activity_type: str) -> BidiInterruptionEvent | None: + """Create standardized interruption event for voice activity.""" + # Only speech_started triggers interruption + if activity_type == "speech_started": + return BidiInterruptionEvent(reason="user_speech") + # Other voice activity events are logged but don't create events + return None + + def _build_session_config(self, system_prompt: str | None, tools: list[ToolSpec] | None) -> dict[str, Any]: + """Build session configuration for OpenAI Realtime API.""" + config: dict[str, Any] = DEFAULT_SESSION_CONFIG.copy() + + if system_prompt: + config["instructions"] = system_prompt + + if tools: + config["tools"] = self._convert_tools_to_openai_format(tools) + + # Apply user-provided session configuration + supported_params = { + "type", + "output_modalities", + "instructions", + "voice", + "tools", + "tool_choice", + "input_audio_format", + "output_audio_format", + "input_audio_transcription", + "turn_detection", + } + + for key, value in self.config.items(): + if key == "audio": + continue + elif key in supported_params: + config[key] = value + else: + logger.warning("parameter=<%s> | ignoring unsupported session parameter", key) + + audio_config = self.config["audio"] + + if "voice" in audio_config: + config.setdefault("audio", {}).setdefault("output", {})["voice"] = audio_config["voice"] + + if "input_rate" in 
audio_config: + config.setdefault("audio", {}).setdefault("input", {}).setdefault("format", {})["rate"] = audio_config[ + "input_rate" + ] + + if "output_rate" in audio_config: + config.setdefault("audio", {}).setdefault("output", {}).setdefault("format", {})["rate"] = audio_config[ + "output_rate" + ] + + return config + + def _convert_tools_to_openai_format(self, tools: list[ToolSpec]) -> list[dict]: + """Convert Strands tool specifications to OpenAI Realtime API format.""" + openai_tools = [] + + for tool in tools: + input_schema = tool["inputSchema"] + if "json" in input_schema: + schema = ( + json.loads(input_schema["json"]) if isinstance(input_schema["json"], str) else input_schema["json"] + ) + else: + schema = input_schema + + # OpenAI Realtime API expects flat structure, not nested under "function" + openai_tool = { + "type": "function", + "name": tool["name"], + "description": tool["description"], + "parameters": schema, + } + openai_tools.append(openai_tool) + + return openai_tools + + async def _add_conversation_history(self, messages: Messages) -> None: + """Add conversation history to the session. + + Converts agent message history to OpenAI Realtime API format using + conversation.item.create events for each message. + + Note: OpenAI Realtime API has a 32-character limit on call_id, so we truncate + UUIDs consistently to ensure tool calls and their results match. + + Args: + messages: List of conversation messages with role and content. 
+ """ + # Track tool call IDs to ensure consistency between calls and results + call_id_map: dict[str, str] = {} + + # First pass: collect all tool call IDs + for message in messages: + for block in message.get("content", []): + if "toolUse" in block: + tool_use = block["toolUse"] + original_id = tool_use["toolUseId"] + call_id = original_id[:32] + call_id_map[original_id] = call_id + + # Second pass: send messages + for message in messages: + role = message["role"] + content_blocks = message.get("content", []) + + # Build content array for OpenAI format + openai_content = [] + + for block in content_blocks: + if "text" in block: + # Text content - use appropriate type based on role + # User messages use "input_text", assistant messages use "output_text" + if role == "user": + openai_content.append({"type": "input_text", "text": block["text"]}) + else: # assistant + openai_content.append({"type": "output_text", "text": block["text"]}) + elif "toolUse" in block: + # Tool use - create as function_call item + tool_use = block["toolUse"] + original_id = tool_use["toolUseId"] + # Use pre-mapped call_id + call_id = call_id_map[original_id] + + tool_item = { + "type": "conversation.item.create", + "item": { + "type": "function_call", + "call_id": call_id, + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]), + }, + } + await self._send_event(tool_item) + continue # Tool use is sent separately, not in message content + elif "toolResult" in block: + # Tool result - create as function_call_output item + tool_result = block["toolResult"] + original_id = tool_result["toolUseId"] + + # Validate content types and serialize, preserving structure + result_output = "" + if "content" in tool_result: + # First validate all content types are supported + for result_block in tool_result["content"]: + if "text" not in result_block and "json" not in result_block: + # Unsupported content type - raise error + raise ValueError( + f"tool_use_id=<{original_id}>, 
content_types=<{list(result_block.keys())}> | " + f"Content type not supported by OpenAI Realtime API" + ) + + # Preserve structure by JSON-dumping the entire content array + result_output = json.dumps(tool_result["content"]) + + # Use mapped call_id if available, otherwise skip orphaned result + if original_id not in call_id_map: + continue # Skip this tool result since we don't have the call + + call_id = call_id_map[original_id] + + result_item = { + "type": "conversation.item.create", + "item": { + "type": "function_call_output", + "call_id": call_id, + "output": result_output, + }, + } + await self._send_event(result_item) + continue # Tool result is sent separately, not in message content + + # Only create message item if there's text content + if openai_content: + conversation_item = { + "type": "conversation.item.create", + "item": {"type": "message", "role": role, "content": openai_content}, + } + await self._send_event(conversation_item) + + logger.debug("message_count=<%d> | conversation history added to openai session", len(messages)) + + async def receive(self) -> AsyncGenerator[BidiOutputEvent, None]: + """Receive OpenAI events and convert to Strands TypedEvent format.""" + if not self._connection_id: + raise RuntimeError("model not started | call start before sending/receiving") + + yield BidiConnectionStartEvent(connection_id=self._connection_id, model=self.model_id) + + while True: + duration = time.time() - self._start_time + if duration >= self.timeout_s: + raise BidiModelTimeoutError(f"timeout_s=<{self.timeout_s}>") + + try: + message = await asyncio.wait_for(self._websocket.recv(), timeout=10) + except asyncio.TimeoutError: + continue + + openai_event = json.loads(message) + + for event in self._convert_openai_event(openai_event) or []: + yield event + + def _convert_openai_event(self, openai_event: dict[str, Any]) -> list[BidiOutputEvent] | None: + """Convert OpenAI events to Strands TypedEvent format.""" + event_type = 
openai_event.get("type") + + # Turn start - response begins + if event_type == "response.created": + response = openai_event.get("response", {}) + response_id = response.get("id", str(uuid.uuid4())) + return [BidiResponseStartEvent(response_id=response_id)] + + # Audio output + elif event_type == "response.output_audio.delta": + # Audio is already base64 string from OpenAI + # Use the resolved output sample rate from our merged configuration + sample_rate = self.config["audio"]["output_rate"] + + # Channels from config is guaranteed to be 1 or 2 + channels = cast(Literal[1, 2], self.config["audio"]["channels"]) + return [ + BidiAudioStreamEvent( + audio=openai_event["delta"], + format="pcm", + sample_rate=sample_rate, + channels=channels, + ) + ] + + # Assistant text output events - combine multiple similar events + elif event_type in ["response.output_text.delta", "response.output_audio_transcript.delta"]: + role = openai_event.get("role", "assistant") + return [ + self._create_text_event( + openai_event["delta"], role.lower() if isinstance(role, str) else "assistant", is_final=False + ) + ] + + elif event_type in ["response.output_audio_transcript.done"]: + role = openai_event.get("role", "assistant").lower() + return [self._create_text_event(openai_event["transcript"], role)] + + elif event_type in ["response.output_text.done"]: + role = openai_event.get("role", "assistant").lower() + return [self._create_text_event(openai_event["text"], role)] + + # User transcription events - combine multiple similar events + elif event_type in [ + "conversation.item.input_audio_transcription.delta", + "conversation.item.input_audio_transcription.completed", + ]: + text_key = "delta" if "delta" in event_type else "transcript" + text = openai_event.get(text_key, "") + role = openai_event.get("role", "user") + is_final = "completed" in event_type + return ( + [self._create_text_event(text, role.lower() if isinstance(role, str) else "user", is_final=is_final)] + if text.strip() + 
else None + ) + + elif event_type == "conversation.item.input_audio_transcription.segment": + segment_data = openai_event.get("segment", {}) + text = segment_data.get("text", "") + role = segment_data.get("role", "user") + return ( + [self._create_text_event(text, role.lower() if isinstance(role, str) else "user")] + if text.strip() + else None + ) + + elif event_type == "conversation.item.input_audio_transcription.failed": + error_info = openai_event.get("error", {}) + logger.warning("error=<%s> | openai transcription failed", error_info.get("message", "unknown error")) + return None + + # Function call processing + elif event_type == "response.function_call_arguments.delta": + call_id = openai_event.get("call_id") + delta = openai_event.get("delta", "") + if call_id: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = {"call_id": call_id, "name": "", "arguments": delta} + else: + self._function_call_buffer[call_id]["arguments"] += delta + return None + + elif event_type == "response.function_call_arguments.done": + call_id = openai_event.get("call_id") + if call_id and call_id in self._function_call_buffer: + function_call = self._function_call_buffer[call_id] + try: + tool_use: ToolUse = { + "toolUseId": call_id, + "name": function_call["name"], + "input": json.loads(function_call["arguments"]) if function_call["arguments"] else {}, + } + del self._function_call_buffer[call_id] + # Return ToolUseStreamEvent for consistency with standard agent + return [ToolUseStreamEvent(delta={"toolUse": tool_use}, current_tool_use=dict(tool_use))] + except (json.JSONDecodeError, KeyError) as e: + logger.warning("call_id=<%s>, error=<%s> | error parsing function arguments", call_id, e) + del self._function_call_buffer[call_id] + return None + + # Voice activity detection - speech_started triggers interruption + elif event_type == "input_audio_buffer.speech_started": + # This is the primary interruption signal - handle it first + return 
[BidiInterruptionEvent(reason="user_speech")] + + # Response cancelled - handle interruption + elif event_type == "response.cancelled": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + logger.debug("response_id=<%s> | openai response cancelled", response_id) + return [BidiResponseCompleteEvent(response_id=response_id, stop_reason="interrupted")] + + # Turn complete and usage - response finished + elif event_type == "response.done": + response = openai_event.get("response", {}) + response_id = response.get("id", "unknown") + status = response.get("status", "completed") + usage = response.get("usage") + + # Map OpenAI status to our stop_reason + stop_reason_map = { + "completed": "complete", + "cancelled": "interrupted", + "failed": "error", + "incomplete": "interrupted", + } + + # Build list of events to return + events: list[Any] = [] + + # Always add response complete event + events.append( + BidiResponseCompleteEvent( + response_id=response_id, + stop_reason=cast(StopReason, stop_reason_map.get(status, "complete")), + ), + ) + + # Add usage event if available + if usage: + input_details = usage.get("input_token_details", {}) + output_details = usage.get("output_token_details", {}) + + # Build modality details + modality_details = [] + + # Text modality + text_input = input_details.get("text_tokens", 0) + text_output = output_details.get("text_tokens", 0) + if text_input > 0 or text_output > 0: + modality_details.append( + {"modality": "text", "input_tokens": text_input, "output_tokens": text_output} + ) + + # Audio modality + audio_input = input_details.get("audio_tokens", 0) + audio_output = output_details.get("audio_tokens", 0) + if audio_input > 0 or audio_output > 0: + modality_details.append( + {"modality": "audio", "input_tokens": audio_input, "output_tokens": audio_output} + ) + + # Image modality + image_input = input_details.get("image_tokens", 0) + if image_input > 0: + modality_details.append({"modality": 
"image", "input_tokens": image_input, "output_tokens": 0}) + + # Cached tokens + cached_tokens = input_details.get("cached_tokens", 0) + + # Add usage event + events.append( + BidiUsageEvent( + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + total_tokens=usage.get("total_tokens", 0), + modality_details=cast(list[ModalityUsage], modality_details) if modality_details else None, + cache_read_input_tokens=cached_tokens if cached_tokens > 0 else None, + ) + ) + + # Return list of events + return events + + # Lifecycle events (log only) - combine multiple similar events + elif event_type in ["conversation.item.retrieve", "conversation.item.added"]: + item = openai_event.get("item", {}) + action = "retrieved" if "retrieve" in event_type else "added" + logger.debug("action=<%s>, item_id=<%s> | openai conversation item event", action, item.get("id")) + return None + + elif event_type == "conversation.item.done": + logger.debug("item_id=<%s> | openai conversation item done", openai_event.get("item", {}).get("id")) + return None + + # Response output events - combine similar events + elif event_type in [ + "response.output_item.added", + "response.output_item.done", + "response.content_part.added", + "response.content_part.done", + ]: + item_data = openai_event.get("item") or openai_event.get("part") + logger.debug( + "event_type=<%s>, item_id=<%s> | openai output event", + event_type, + item_data.get("id") if item_data else "unknown", + ) + + # Track function call names from response.output_item.added + if event_type == "response.output_item.added": + item = openai_event.get("item", {}) + if item.get("type") == "function_call": + call_id = item.get("call_id") + function_name = item.get("name") + if call_id and function_name: + if call_id not in self._function_call_buffer: + self._function_call_buffer[call_id] = { + "call_id": call_id, + "name": function_name, + "arguments": "", + } + else: + 
self._function_call_buffer[call_id]["name"] = function_name + return None + + # Session/buffer events - combine simple log-only events + elif event_type in [ + "input_audio_buffer.committed", + "input_audio_buffer.cleared", + "session.created", + "session.updated", + ]: + logger.debug("event_type=<%s> | openai event received", event_type) + return None + + elif event_type == "error": + error_data = openai_event.get("error", {}) + error_code = error_data.get("code", "") + + # Suppress expected errors that don't affect session state + if error_code == "response_cancel_not_active": + # This happens when trying to cancel a response that's not active + # It's safe to ignore as the session remains functional + logger.debug("openai response cancel attempted when no response active") + return None + + # Log other errors + logger.error("error=<%s> | openai realtime error", error_data) + return None + + else: + logger.debug("event_type=<%s> | unhandled openai event type", event_type) + return None + + async def send( + self, + content: BidiInputEvent | ToolResultEvent, + ) -> None: + """Unified send method for all content types. Sends the given content to OpenAI. + + Dispatches to appropriate internal handler based on content type. + + Args: + content: Typed event (BidiTextInputEvent, BidiAudioInputEvent, BidiImageInputEvent, or ToolResultEvent). + + Raises: + ValueError: If content type not supported (e.g., image content). 
+ """ + if not self._connection_id: + raise RuntimeError("model not started | call start before sending") + + # Note: TypedEvent inherits from dict, so isinstance checks for TypedEvent must come first + if isinstance(content, BidiTextInputEvent): + await self._send_text_content(content.text) + elif isinstance(content, BidiAudioInputEvent): + await self._send_audio_content(content) + elif isinstance(content, ToolResultEvent): + tool_result = content.get("tool_result") + if tool_result: + await self._send_tool_result(tool_result) + else: + raise ValueError(f"content_type={type(content)} | content not supported") + + async def _send_audio_content(self, audio_input: BidiAudioInputEvent) -> None: + """Internal: Send audio content to OpenAI for processing.""" + # Audio is already base64 encoded in the event + await self._send_event({"type": "input_audio_buffer.append", "audio": audio_input.audio}) + + async def _send_text_content(self, text: str) -> None: + """Internal: Send text content to OpenAI for processing.""" + item_data = {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]} + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def _send_interrupt(self) -> None: + """Internal: Send interruption signal to OpenAI.""" + await self._send_event({"type": "response.cancel"}) + + async def _send_tool_result(self, tool_result: ToolResult) -> None: + """Internal: Send tool result back to OpenAI.""" + tool_use_id = tool_result.get("toolUseId") + + logger.debug("tool_use_id=<%s> | sending openai tool result", tool_use_id) + + # Validate content types and serialize, preserving structure + result_output = "" + if "content" in tool_result: + # First validate all content types are supported + for block in tool_result["content"]: + if "text" not in block and "json" not in block: + # Unsupported content type - raise error + raise ValueError( + 
f"tool_use_id=<{tool_use_id}>, content_types=<{list(block.keys())}> | " + f"Content type not supported by OpenAI Realtime API" + ) + + # Preserve structure by JSON-dumping the entire content array + result_output = json.dumps(tool_result["content"]) + + item_data = {"type": "function_call_output", "call_id": tool_use_id, "output": result_output} + await self._send_event({"type": "conversation.item.create", "item": item_data}) + await self._send_event({"type": "response.create"}) + + async def stop(self) -> None: + """Close session and cleanup resources.""" + logger.debug("openai realtime connection cleanup starting") + + async def stop_websocket() -> None: + if not hasattr(self, "_websocket"): + return + + await self._websocket.close() + + async def stop_connection() -> None: + self._connection_id = None + + await stop_all(stop_websocket, stop_connection) + + logger.debug("openai realtime connection closed") + + async def _send_event(self, event: dict[str, Any]) -> None: + """Send event to OpenAI via WebSocket.""" + message = json.dumps(event) + await self._websocket.send(message) + logger.debug("event_type=<%s> | openai event sent", event.get("type")) diff --git a/src/strands/experimental/bidi/types/model.py b/src/strands/experimental/bidi/types/model.py new file mode 100644 index 000000000..de41de1a9 --- /dev/null +++ b/src/strands/experimental/bidi/types/model.py @@ -0,0 +1,36 @@ +"""Model-related type definitions for bidirectional streaming. + +Defines types and configurations that are central to model providers, +including audio configuration that models use to specify their audio +processing requirements. +""" + +from typing import TypedDict + +from .events import AudioChannel, AudioFormat, AudioSampleRate + + +class AudioConfig(TypedDict, total=False): + """Audio configuration for bidirectional streaming models. + + Defines standard audio parameters that model providers use to specify + their audio processing requirements. 
All fields are optional to support + models that may not use audio or only need specific parameters. + + Model providers build this configuration by merging user-provided values + with their own defaults. The resulting configuration is then used by + audio I/O implementations to configure hardware appropriately. + + Attributes: + input_rate: Input sample rate in Hz (e.g., 16000, 24000, 48000) + output_rate: Output sample rate in Hz (e.g., 16000, 24000, 48000) + channels: Number of audio channels (1=mono, 2=stereo) + format: Audio encoding format + voice: Voice identifier for text-to-speech (e.g., "alloy", "matthew") + """ + + input_rate: AudioSampleRate + output_rate: AudioSampleRate + channels: AudioChannel + format: AudioFormat + voice: str From e472b929e98a2a3cd5dca9cb3e77cb144b29a8fd Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 18:06:10 -0500 Subject: [PATCH 3/6] addressed comments --- src/strands/experimental/bidi/__init__.py | 6 +++--- src/strands/experimental/bidi/models/__init__.py | 2 +- src/strands/experimental/bidi/models/gemini_live.py | 4 +--- src/strands/experimental/bidi/models/nova_sonic.py | 3 +-- src/strands/experimental/bidi/models/openai_realtime.py | 2 +- .../bidi/models/{test_novasonic.py => test_nova_sonic.py} | 0 .../bidi/models/{test_openai.py => test_openai_realtime.py} | 0 7 files changed, 7 insertions(+), 10 deletions(-) rename tests/strands/experimental/bidi/models/{test_novasonic.py => test_nova_sonic.py} (100%) rename tests/strands/experimental/bidi/models/{test_openai.py => test_openai_realtime.py} (100%) diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index 13c5b51e1..d274bfbcb 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -17,11 +17,11 @@ # IO channels - Hardware abstraction from .io.audio import BidiAudioIO -# Model interface (for custom implementations) -from .models.model import BidiModel - # Model 
providers - What users need to create models from .models.gemini_live import BidiGeminiLiveModel + +# Model interface (for custom implementations) +from .models.model import BidiModel from .models.nova_sonic import BidiNovaSonicModel from .models.openai_realtime import BidiOpenAIRealtimeModel diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index 29a2229c5..b56208c1e 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,7 +1,7 @@ """Bidirectional model interfaces and implementations.""" -from .model import BidiModel, BidiModelTimeoutError from .gemini_live import BidiGeminiLiveModel +from .model import BidiModel, BidiModelTimeoutError from .nova_sonic import BidiNovaSonicModel from .openai_realtime import BidiOpenAIRealtimeModel diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index 4f7a9db44..a267211d1 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -4,7 +4,6 @@ official Google GenAI SDK for simplified and robust WebSocket communication. Key improvements over custom WebSocket implementation: - - Uses official google-genai SDK with native Live API support - Simplified session management with client.aio.live.connect() - Built-in tool integration and event handling @@ -25,7 +24,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all -from ..types.bidi_model import AudioConfig +from ..types.model import AudioConfig from ..types.events import ( AudioChannel, AudioSampleRate, @@ -222,7 +221,6 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut """Convert Gemini Live API events to provider-agnostic format. 
Handles different types of content: - - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 04037a90f..9ccc3d58f 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -5,7 +5,6 @@ InvokeModelWithBidirectionalStream protocol. Nova Sonic specifics: - - Hierarchical event sequences: connectionStart → promptStart → content streaming - Base64-encoded audio format with hex encoding - Tool execution with content containers and identifier tracking @@ -37,7 +36,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all -from ..types.bidi_model import AudioConfig +from ..types.model import AudioConfig from ..types.events import ( AudioChannel, AudioSampleRate, diff --git a/src/strands/experimental/bidi/models/openai_realtime.py b/src/strands/experimental/bidi/models/openai_realtime.py index 9a4584365..39312c7d3 100644 --- a/src/strands/experimental/bidi/models/openai_realtime.py +++ b/src/strands/experimental/bidi/models/openai_realtime.py @@ -19,7 +19,7 @@ from ....types.content import Messages from ....types.tools import ToolResult, ToolSpec, ToolUse from .._async import stop_all -from ..types.bidi_model import AudioConfig +from ..types.model import AudioConfig from ..types.events import ( AudioSampleRate, BidiAudioInputEvent, diff --git a/tests/strands/experimental/bidi/models/test_novasonic.py b/tests/strands/experimental/bidi/models/test_nova_sonic.py similarity index 100% rename from tests/strands/experimental/bidi/models/test_novasonic.py rename to tests/strands/experimental/bidi/models/test_nova_sonic.py diff --git a/tests/strands/experimental/bidi/models/test_openai.py 
b/tests/strands/experimental/bidi/models/test_openai_realtime.py similarity index 100% rename from tests/strands/experimental/bidi/models/test_openai.py rename to tests/strands/experimental/bidi/models/test_openai_realtime.py From 75a77751a2458e75afbb9f52b765c77698a022d2 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 18:23:23 -0500 Subject: [PATCH 4/6] addrsssed comments --- pyproject.toml | 2 ++ scripts/bidi/test_bidi_openai.py | 2 +- src/strands/experimental/bidi/__init__.py | 6 ------ src/strands/experimental/bidi/models/__init__.py | 4 ---- src/strands/experimental/bidi/models/gemini_live.py | 1 + src/strands/experimental/bidi/models/nova_sonic.py | 1 + 6 files changed, 5 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2a8b250fe..944a1b3a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,8 @@ bidi = [ "prompt_toolkit>=3.0.0,<4.0.0", "pyaudio>=0.2.13,<1.0.0", "smithy-aws-core>=0.0.1; python_version>='3.12'", + "google-genai>=1.32.0,<2.0.0", + "websockets>=15.0.0,<16.0.0", ] bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] bidi-openai = ["websockets>=15.0.0,<16.0.0"] diff --git a/scripts/bidi/test_bidi_openai.py b/scripts/bidi/test_bidi_openai.py index 50d2d2f55..677c12981 100644 --- a/scripts/bidi/test_bidi_openai.py +++ b/scripts/bidi/test_bidi_openai.py @@ -10,7 +10,7 @@ from strands_tools import calculator from strands.experimental.bidi.agent.agent import BidiAgent -from strands.experimental.bidi.models.openai import BidiOpenAIRealtimeModel +from strands.experimental.bidi.models.openai_realtime import BidiOpenAIRealtimeModel async def play(context): diff --git a/src/strands/experimental/bidi/__init__.py b/src/strands/experimental/bidi/__init__.py index d274bfbcb..57986062e 100644 --- a/src/strands/experimental/bidi/__init__.py +++ b/src/strands/experimental/bidi/__init__.py @@ -17,13 +17,9 @@ # IO channels - Hardware abstraction from .io.audio import BidiAudioIO -# Model providers - What users 
need to create models -from .models.gemini_live import BidiGeminiLiveModel - # Model interface (for custom implementations) from .models.model import BidiModel from .models.nova_sonic import BidiNovaSonicModel -from .models.openai_realtime import BidiOpenAIRealtimeModel # Built-in tools from .tools import stop_conversation @@ -53,9 +49,7 @@ # IO channels "BidiAudioIO", # Model providers - "BidiGeminiLiveModel", "BidiNovaSonicModel", - "BidiOpenAIRealtimeModel", # Built-in tools "stop_conversation", # Input Event types diff --git a/src/strands/experimental/bidi/models/__init__.py b/src/strands/experimental/bidi/models/__init__.py index b56208c1e..cc62c9987 100644 --- a/src/strands/experimental/bidi/models/__init__.py +++ b/src/strands/experimental/bidi/models/__init__.py @@ -1,14 +1,10 @@ """Bidirectional model interfaces and implementations.""" -from .gemini_live import BidiGeminiLiveModel from .model import BidiModel, BidiModelTimeoutError from .nova_sonic import BidiNovaSonicModel -from .openai_realtime import BidiOpenAIRealtimeModel __all__ = [ "BidiModel", "BidiModelTimeoutError", - "BidiGeminiLiveModel", "BidiNovaSonicModel", - "BidiOpenAIRealtimeModel", ] diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index a267211d1..ca69b9453 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -4,6 +4,7 @@ official Google GenAI SDK for simplified and robust WebSocket communication. 
Key improvements over custom WebSocket implementation: + - Uses official google-genai SDK with native Live API support - Simplified session management with client.aio.live.connect() - Built-in tool integration and event handling diff --git a/src/strands/experimental/bidi/models/nova_sonic.py b/src/strands/experimental/bidi/models/nova_sonic.py index 9ccc3d58f..0cfa51181 100644 --- a/src/strands/experimental/bidi/models/nova_sonic.py +++ b/src/strands/experimental/bidi/models/nova_sonic.py @@ -5,6 +5,7 @@ InvokeModelWithBidirectionalStream protocol. Nova Sonic specifics: + - Hierarchical event sequences: connectionStart → promptStart → content streaming - Base64-encoded audio format with hex encoding - Tool execution with content containers and identifier tracking From 69d8f09134045f16e4088ca2d66d77a26e248bb0 Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 18:30:09 -0500 Subject: [PATCH 5/6] address comments --- pyproject.toml | 2 -- src/strands/experimental/bidi/models/gemini_live.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 944a1b3a5..2a8b250fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,8 +75,6 @@ bidi = [ "prompt_toolkit>=3.0.0,<4.0.0", "pyaudio>=0.2.13,<1.0.0", "smithy-aws-core>=0.0.1; python_version>='3.12'", - "google-genai>=1.32.0,<2.0.0", - "websockets>=15.0.0,<16.0.0", ] bidi-gemini = ["google-genai>=1.32.0,<2.0.0"] bidi-openai = ["websockets>=15.0.0,<16.0.0"] diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index ca69b9453..dc3810520 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -222,6 +222,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut """Convert Gemini Live API events to provider-agnostic format. 
Handles different types of content: + - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model From 98294945ac3bbab041e459fb24c861d2eda9070c Mon Sep 17 00:00:00 2001 From: Rachit Mehta Date: Sun, 30 Nov 2025 18:38:12 -0500 Subject: [PATCH 6/6] minor update --- src/strands/experimental/bidi/models/gemini_live.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/strands/experimental/bidi/models/gemini_live.py b/src/strands/experimental/bidi/models/gemini_live.py index dc3810520..3af8d707f 100644 --- a/src/strands/experimental/bidi/models/gemini_live.py +++ b/src/strands/experimental/bidi/models/gemini_live.py @@ -222,7 +222,7 @@ def _convert_gemini_live_event(self, message: LiveServerMessage) -> list[BidiOut """Convert Gemini Live API events to provider-agnostic format. Handles different types of content: - + - inputTranscription: User's speech transcribed to text - outputTranscription: Model's audio transcribed to text - modelTurn text: Text response from the model