diff --git a/decart/lipsync/client.py b/decart/lipsync/client.py index 3f6dbcf..d3bafea 100644 --- a/decart/lipsync/client.py +++ b/decart/lipsync/client.py @@ -23,7 +23,6 @@ class RealtimeLipsyncClient: - DECART_LIPSYNC_ENDPOINT = "/router/lipsync/ws" VIDEO_FPS = 25 diff --git a/decart/realtime/client.py b/decart/realtime/client.py index 4b4a4f9..37d79b2 100644 --- a/decart/realtime/client.py +++ b/decart/realtime/client.py @@ -1,4 +1,5 @@ from typing import Callable, Optional +import asyncio import logging import uuid from aiortc import MediaStreamTrack @@ -81,7 +82,23 @@ def _emit_error(self, error: DecartSDKError) -> None: async def set_prompt(self, prompt: str, enrich: bool = True) -> None: if not prompt or not prompt.strip(): raise InvalidInputError("Prompt cannot be empty") - await self._manager.send_message(PromptMessage(type="prompt", prompt=prompt)) + + event, result = self._manager.register_prompt_wait(prompt) + + try: + await self._manager.send_message( + PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enrich) + ) + + try: + await asyncio.wait_for(event.wait(), timeout=15.0) + except asyncio.TimeoutError: + raise DecartSDKError("Prompt acknowledgment timed out") + + if not result["success"]: + raise DecartSDKError(result["error"] or "Prompt failed") + finally: + self._manager.unregister_prompt_wait(prompt) def is_connected(self) -> bool: return self._manager.is_connected() diff --git a/decart/realtime/messages.py b/decart/realtime/messages.py index 81f05af..c59c5c6 100644 --- a/decart/realtime/messages.py +++ b/decart/realtime/messages.py @@ -1,4 +1,4 @@ -from typing import Literal, Union, Annotated +from typing import Literal, Optional, Union, Annotated from pydantic import BaseModel, Field, TypeAdapter try: @@ -42,9 +42,18 @@ class SessionIdMessage(BaseModel): server_ip: str +class PromptAckMessage(BaseModel): + """Acknowledgment for prompt update from server.""" + + type: Literal["prompt_ack"] + prompt: str + success: bool + error: Optional[str] = None + + # Discriminated union for incoming messages IncomingMessage = Annotated[ - Union[AnswerMessage, IceCandidateMessage, SessionIdMessage], + Union[AnswerMessage, IceCandidateMessage, SessionIdMessage, PromptAckMessage], Field(discriminator="type"), ] @@ -67,6 +76,7 @@ class PromptMessage(BaseModel): type: Literal["prompt"] prompt: str + enhance_prompt: bool = True # Outgoing message union (no discriminator needed - we know what we're sending) diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index b664b27..ad087e4 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -21,6 +21,7 @@ OfferMessage, IceCandidateMessage, IceCandidatePayload, + PromptAckMessage, OutgoingMessage, ) from .types import ConnectionState @@ -36,7 +37,6 @@ def __init__( on_error: Optional[Callable[[Exception], None]] = None, customize_offer: Optional[Callable] = None, ): - self._pc: Optional[RTCPeerConnection] = None self._ws: Optional[aiohttp.ClientWebSocketResponse] = None self._session: Optional[aiohttp.ClientSession] = None @@ -47,6 +47,7 @@ def __init__( self._customize_offer = customize_offer self._ws_task: Optional[asyncio.Task] = None self._ice_candidates_queue: list[RTCIceCandidate] = [] + self._pending_prompts: dict[str, tuple[asyncio.Event, dict]] = {} async def connect( self, @@ -176,6 +177,8 @@ async def _handle_message(self, data: dict) -> None: await self._handle_ice_candidate(message.candidate) elif message.type == "session_id": logger.debug(f"Session ID: {message.session_id}") + elif message.type == "prompt_ack": + self._handle_prompt_ack(message) async def _handle_answer(self, sdp: str) -> None: logger.debug("Received answer from server") @@ -207,6 +210,23 @@ async def _handle_ice_candidate(self, candidate_data: IceCandidatePayload) -> No logger.debug("Queuing ICE candidate (no remote description yet)") self._ice_candidates_queue.append(candidate) + def _handle_prompt_ack(self, message: PromptAckMessage) -> None: + logger.debug(f"Received prompt_ack for: {message.prompt}, success: {message.success}") + if message.prompt in self._pending_prompts: + event, result = self._pending_prompts[message.prompt] + result["success"] = message.success + result["error"] = message.error + event.set() + + def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]: + event = asyncio.Event() + result: dict = {"success": False, "error": None} + self._pending_prompts[prompt] = (event, result) + return event, result + + def unregister_prompt_wait(self, prompt: str) -> None: + self._pending_prompts.pop(prompt, None) + async def _send_message(self, message: OutgoingMessage) -> None: if not self._ws or self._ws.closed: raise RuntimeError("WebSocket not connected") diff --git a/decart/realtime/webrtc_manager.py b/decart/realtime/webrtc_manager.py index c798833..d39a067 100644 --- a/decart/realtime/webrtc_manager.py +++ b/decart/realtime/webrtc_manager.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import Optional, Callable from dataclasses import dataclass @@ -84,3 +85,9 @@ def is_connected(self) -> bool: def get_connection_state(self) -> ConnectionState: return self._connection.state + + def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]: + return self._connection.register_prompt_wait(prompt) + + def unregister_prompt_wait(self, prompt: str) -> None: + self._connection.unregister_prompt_wait(prompt) diff --git a/examples/lipsync_file.py b/examples/lipsync_file.py index 4b0d0b2..006f95c 100644 --- a/examples/lipsync_file.py +++ b/examples/lipsync_file.py @@ -72,7 +72,6 @@ async def process_lipsync(video_path: str, audio_path: str, output_path: str): ) for i in range(frame_count): try: - video_frame, audio_frame = await client.get_synced_output(timeout=1.0) bgr_frame = cv2.cvtColor(video_frame, cv2.COLOR_RGB2BGR) out.write(bgr_frame) diff --git a/examples/realtime_synthetic.py b/examples/realtime_synthetic.py index 6a27b98..b251814 100644 --- a/examples/realtime_synthetic.py +++ b/examples/realtime_synthetic.py @@ -126,7 +126,11 @@ def on_error(error): await asyncio.sleep(5) print("\n🎨 Changing style to 'Cyberpunk city'...") - await realtime_client.set_prompt("Cyberpunk city") + try: + await realtime_client.set_prompt("Cyberpunk city") + print("✓ Prompt set successfully") + except Exception as e: + print(f"⚠️ Failed to set prompt: {e}") await asyncio.sleep(5) diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index 768149f..fee649d 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -44,6 +44,8 @@ def test_realtime_models_available(): @pytest.mark.asyncio async def test_realtime_client_creation_with_mock(): """Test client creation with mocked WebRTC""" + import asyncio + client = DecartClient(api_key="test-key") with patch("decart.realtime.client.WebRTCManager") as mock_manager_class: @@ -51,6 +53,14 @@ async def test_realtime_client_creation_with_mock(): mock_manager.connect = AsyncMock(return_value=True) mock_manager.is_connected = MagicMock(return_value=True) mock_manager.get_connection_state = MagicMock(return_value="connected") + mock_manager.send_message = AsyncMock() + + prompt_event = asyncio.Event() + prompt_result = {"success": True, "error": None} + prompt_event.set() + + mock_manager.register_prompt_wait = MagicMock(return_value=(prompt_event, prompt_result)) + mock_manager.unregister_prompt_wait = MagicMock() mock_manager_class.return_value = mock_manager mock_track = MagicMock() @@ -76,13 +86,24 @@ async def test_realtime_client_creation_with_mock(): @pytest.mark.asyncio async def test_realtime_set_prompt_with_mock(): - """Test set_prompt with mocked WebRTC""" + """Test set_prompt with mocked WebRTC and prompt_ack""" + import asyncio + client = DecartClient(api_key="test-key") with patch("decart.realtime.client.WebRTCManager") as mock_manager_class: mock_manager = AsyncMock() mock_manager.connect = AsyncMock(return_value=True) mock_manager.send_message = AsyncMock() + + prompt_event = asyncio.Event() + prompt_result = {"success": True, "error": None} + + def register_prompt_wait(prompt): + return prompt_event, prompt_result + + mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait) + mock_manager.unregister_prompt_wait = MagicMock() mock_manager_class.return_value = mock_manager mock_track = MagicMock() @@ -99,12 +120,19 @@ async def test_realtime_set_prompt_with_mock(): ), ) + async def set_event(): + await asyncio.sleep(0.01) + prompt_event.set() + + asyncio.create_task(set_event()) await realtime_client.set_prompt("New prompt") - mock_manager.send_message.assert_called_once() + mock_manager.send_message.assert_called() call_args = mock_manager.send_message.call_args[0][0] assert call_args.type == "prompt" assert call_args.prompt == "New prompt" + assert call_args.enhance_prompt is True + mock_manager.unregister_prompt_wait.assert_called_with("New prompt") @pytest.mark.asyncio @@ -152,3 +180,101 @@ def on_error(error): realtime_client._emit_error(test_error) assert len(errors) == 1 assert errors[0].message == "Test error" + + +@pytest.mark.asyncio +async def test_realtime_set_prompt_timeout(): + """Test set_prompt raises on timeout""" + import asyncio + + client = DecartClient(api_key="test-key") + + with patch("decart.realtime.client.WebRTCManager") as mock_manager_class: + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.send_message = AsyncMock() + + prompt_event = asyncio.Event() + prompt_result = {"success": False, "error": None} + + def register_prompt_wait(prompt): + return prompt_event, prompt_result + + mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait) + mock_manager.unregister_prompt_wait = MagicMock() + mock_manager_class.return_value = mock_manager + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), + on_remote_stream=lambda t: None, + ), + ) + + from decart.errors import DecartSDKError + + # Mock asyncio.wait_for to immediately raise TimeoutError + with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError): + with pytest.raises(DecartSDKError) as exc_info: + await realtime_client.set_prompt("New prompt") + + assert "timed out" in str(exc_info.value) + mock_manager.unregister_prompt_wait.assert_called_with("New prompt") + + +@pytest.mark.asyncio +async def test_realtime_set_prompt_server_error(): + """Test set_prompt raises on server error""" + import asyncio + + client = DecartClient(api_key="test-key") + + with patch("decart.realtime.client.WebRTCManager") as mock_manager_class: + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.send_message = AsyncMock() + + prompt_event = asyncio.Event() + prompt_result = {"success": False, "error": "Server rejected prompt"} + + def register_prompt_wait(prompt): + return prompt_event, prompt_result + + mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait) + mock_manager.unregister_prompt_wait = MagicMock() + mock_manager_class.return_value = mock_manager + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), + on_remote_stream=lambda t: None, + ), + ) + + async def set_event(): + await asyncio.sleep(0.01) + prompt_event.set() + + asyncio.create_task(set_event()) + + from decart.errors import DecartSDKError + + with pytest.raises(DecartSDKError) as exc_info: + await realtime_client.set_prompt("New prompt") + + assert "Server rejected prompt" in str(exc_info.value) + mock_manager.unregister_prompt_wait.assert_called_with("New prompt") diff --git a/uv.lock b/uv.lock index 0b2dadf..a13836a 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.9" resolution-markers = [ "python_full_version >= '3.12' and sys_platform == 'darwin'", @@ -900,7 +900,7 @@ wheels = [ [[package]] name = "decart" -version = "0.0.8" +version = "0.0.11" source = { editable = "." } dependencies = [ { name = "aiofiles" },