diff --git a/src/agents/realtime/openai_realtime.py b/src/agents/realtime/openai_realtime.py
index 236162622..4df0e93c6 100644
--- a/src/agents/realtime/openai_realtime.py
+++ b/src/agents/realtime/openai_realtime.py
@@ -12,6 +12,7 @@
 import pydantic
 import websockets
 from openai.types.realtime import realtime_audio_config as _rt_audio_config
+from openai.types.realtime.audio_transcription import AudioTranscription
 from openai.types.realtime.conversation_item import (
     ConversationItem,
     ConversationItem as OpenAIConversationItem,
@@ -87,6 +88,7 @@
 from agents.realtime._default_tracker import ModelAudioTracker
 from agents.realtime.audio_formats import to_realtime_audio_format
 from agents.tool import FunctionTool, Tool
+from agents.util._pydantic import coerce_model_with_literal_fallback
 from agents.util._types import MaybeAwaitable
 
 from ..exceptions import UserError
@@ -883,7 +885,9 @@ def _get_session_config(
                 "modalities", DEFAULT_MODEL_SETTINGS.get("modalities")
             ),
             audio=OpenAIRealtimeAudioConfig(
-                input=OpenAIRealtimeAudioInput(**audio_input_args),  # type: ignore[arg-type]
+                input=OpenAIRealtimeAudioInput(
+                    **_AudioTranscriptionHelper.prepare_audio_input_args(audio_input_args)
+                ),
                 output=OpenAIRealtimeAudioOutput(**audio_output_args),  # type: ignore[arg-type]
             ),
             tools=cast(
@@ -958,6 +962,38 @@ async def connect(self, options: RealtimeModelConfig) -> None:
         await super().connect(sip_options)
 
 
+class _AudioTranscriptionHelper:
+    """Static helpers that coerce transcription configs, tolerating unknown model names."""
+
+    @staticmethod
+    def prepare_audio_input_args(audio_input_args: dict[str, Any]) -> dict[str, Any]:
+        """Copy the audio input args, coercing a non-None "transcription" entry if present."""
+        prepared_args = dict(audio_input_args)
+        transcription_config = prepared_args.get("transcription")
+        if transcription_config is None:
+            return prepared_args
+
+        prepared_args["transcription"] = _AudioTranscriptionHelper._coerce_audio_transcription(
+            transcription_config
+        )
+        return prepared_args
+
+    @staticmethod
+    def _coerce_audio_transcription(transcription_config: Any) -> Any:
+        """Coerce a mapping into AudioTranscription; return instances and non-mappings as-is."""
+        if isinstance(transcription_config, AudioTranscription):
+            return transcription_config
+
+        if not isinstance(transcription_config, Mapping):
+            return transcription_config
+
+        return coerce_model_with_literal_fallback(
+            AudioTranscription,
+            transcription_config,
+            literal_error_locs=[("model",), ("transcription", "model")],
+        )
+
+
 class _ConversionHelper:
     @classmethod
     def conversation_item_to_realtime_message_item(
diff --git a/src/agents/util/_pydantic.py b/src/agents/util/_pydantic.py
new file mode 100644
index 000000000..53ae1edd9
--- /dev/null
+++ b/src/agents/util/_pydantic.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+from collections.abc import Mapping, Sequence
+from typing import Any, Protocol, TypeVar, cast
+
+import pydantic
+
+# Helpers to tolerate forward-compatible Pydantic literal changes (e.g., date-suffixed model names).
+
+_PydanticModelT = TypeVar("_PydanticModelT", bound="PydanticModelProtocol")
+
+
+class PydanticModelProtocol(Protocol):
+    """Subset of the Pydantic API we need for validation and construction."""
+
+    @classmethod
+    def model_validate(cls: type[_PydanticModelT], data: Any) -> _PydanticModelT: ...
+
+    @classmethod
+    def model_construct(cls: type[_PydanticModelT], **kwargs: Any) -> _PydanticModelT: ...
+
+
+def coerce_model_with_literal_fallback(
+    model_cls: type[_PydanticModelT],
+    data: Any,
+    *,
+    literal_error_locs: Sequence[tuple[str, ...]],
+) -> _PydanticModelT:
+    """Validate data, falling back to model_construct only for allowed literal errors."""
+    if isinstance(data, model_cls):
+        return data
+
+    if not isinstance(data, Mapping):
+        return cast(_PydanticModelT, data)
+
+    try:
+        return model_cls.model_validate(data)
+    except pydantic.ValidationError as exc:
+        if _has_literal_error(exc, literal_error_locs):
+            # Unknown literal value only (e.g. a newer model name): keep the data as given.
+            return model_cls.model_construct(**dict(data))
+        raise
+
+
+def _has_literal_error(
+    exc: pydantic.ValidationError, literal_error_locs: Sequence[tuple[str, ...]]
+) -> bool:
+    """Return True when every error is a literal_error at one of the provided locations."""
+    literal_locs = set(literal_error_locs)
+    errors = exc.errors()
+
+    # Any unrelated failure (wrong type, missing field, ...) must still reach the caller;
+    # model_construct skips validation and would otherwise hide it behind the literal miss.
+    return bool(errors) and all(
+        error.get("type") == "literal_error" and tuple(error.get("loc") or ()) in literal_locs
+        for error in errors
+    )
diff --git a/tests/realtime/test_openai_realtime.py b/tests/realtime/test_openai_realtime.py
index 5954bbc93..ba9e90668 100644
--- a/tests/realtime/test_openai_realtime.py
+++ b/tests/realtime/test_openai_realtime.py
@@ -3,8 +3,10 @@
 from typing import Any, cast
 from unittest.mock import AsyncMock, Mock, patch
 
+import pydantic
 import pytest
 import websockets
+from openai.types.realtime.audio_transcription import AudioTranscription
 
 from agents import Agent
 from agents.exceptions import UserError
@@ -21,7 +23,10 @@
     RealtimeModelSendToolOutput,
     RealtimeModelSendUserInput,
 )
-from agents.realtime.openai_realtime import OpenAIRealtimeWebSocketModel
+from agents.realtime.openai_realtime import (
+    OpenAIRealtimeWebSocketModel,
+    _AudioTranscriptionHelper,
+)
 
 
 class TestOpenAIRealtimeWebSocketModel:
@@ -646,6 +651,15 @@ def test_get_and_update_session_config(self, model):
         assert cfg.audio is not None and cfg.audio.output is not None
         assert cfg.audio.output.voice == "verse"
 
+    def test_session_config_allows_new_transcription_models(self, model):
+        cfg = model._get_session_config(
+            {"input_audio_transcription": {"model": "gpt-4o-mini-transcribe-2025-12-15"}}
+        )
+        assert cfg.audio is not None
+        assert cfg.audio.input is not None
+        assert cfg.audio.input.transcription is not None
+        assert cfg.audio.input.transcription.model == "gpt-4o-mini-transcribe-2025-12-15"
+
     def test_session_config_defaults_audio_formats_when_not_call(self, model):
         settings: dict[str, Any] = {}
         cfg = model._get_session_config(settings)
@@ -657,6 +671,28 @@ def test_session_config_defaults_audio_formats_when_not_call(self, model):
         assert cfg.audio.output.format is not None
         assert cfg.audio.output.format.type == "audio/pcm"
 
+    def test_audio_transcription_helper_accepts_new_models(self):
+        args = {"transcription": {"model": "gpt-4o-mini-transcribe-2025-12-15"}}
+        prepared = _AudioTranscriptionHelper.prepare_audio_input_args(args)
+        assert prepared is not args
+        transcription = prepared["transcription"]
+        assert isinstance(transcription, AudioTranscription)
+        assert transcription.model is not None
+        assert str(transcription.model) == "gpt-4o-mini-transcribe-2025-12-15"
+
+    def test_audio_transcription_helper_returns_copy_without_transcription(self):
+        args = {"format": "pcm16"}
+        prepared = _AudioTranscriptionHelper.prepare_audio_input_args(args)
+        assert prepared is not args
+        assert prepared == args
+
+    def test_audio_transcription_helper_raises_on_non_literal_error(self):
+        # Non-literal validation errors should still surface to the caller.
+        with pytest.raises(pydantic.ValidationError):
+            _AudioTranscriptionHelper._coerce_audio_transcription(
+                {"model": "gpt-4o-mini-transcribe", "language": 123}  # invalid language type
+            )
+
     def test_session_config_preserves_sip_audio_formats(self, model):
         model._call_id = "call-123"
         settings = {
diff --git a/tests/util/test_pydantic.py b/tests/util/test_pydantic.py
new file mode 100644
index 000000000..64e300d47
--- /dev/null
+++ b/tests/util/test_pydantic.py
@@ -0,0 +1,44 @@
+from __future__ import annotations
+
+from typing import Literal
+
+import pytest
+from pydantic import BaseModel, ValidationError
+
+from agents.util._pydantic import coerce_model_with_literal_fallback
+
+
+def test_coerce_model_with_literal_fallback_accepts_literal_miss():
+    class LiteralToyModel(BaseModel):
+        kind: str
+        mode: Literal["a", "b"]
+
+    obj = coerce_model_with_literal_fallback(
+        LiteralToyModel,
+        {"kind": "x", "mode": "c"},
+        literal_error_locs=[("mode",)],
+    )
+    assert isinstance(obj, LiteralToyModel)
+    assert str(obj.mode) == "c"
+
+
+def test_coerce_model_with_literal_fallback_propagates_other_errors():
+    class OtherModel(BaseModel):
+        field: int
+
+    with pytest.raises(ValidationError):
+        coerce_model_with_literal_fallback(OtherModel, {"field": "oops"}, literal_error_locs=[])
+
+
+def test_coerce_model_with_literal_fallback_rejects_mixed_errors():
+    class MixedModel(BaseModel):
+        mode: Literal["a", "b"]
+        count: int
+
+    # A tolerated literal miss must not mask an unrelated validation failure.
+    with pytest.raises(ValidationError):
+        coerce_model_with_literal_fallback(
+            MixedModel,
+            {"mode": "c", "count": "oops"},
+            literal_error_locs=[("mode",)],
+        )