Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion decart/lipsync/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@


class RealtimeLipsyncClient:

DECART_LIPSYNC_ENDPOINT = "/router/lipsync/ws"
VIDEO_FPS = 25

Expand Down
19 changes: 18 additions & 1 deletion decart/realtime/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Callable, Optional
import asyncio
import logging
import uuid
from aiortc import MediaStreamTrack
Expand Down Expand Up @@ -81,7 +82,23 @@ def _emit_error(self, error: DecartSDKError) -> None:
async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
    """Send a new prompt to the server and wait for its acknowledgment.

    Args:
        prompt: Prompt text; must contain at least one non-whitespace
            character.
        enrich: Forwarded to the server as the message's ``enhance_prompt``
            flag.

    Raises:
        InvalidInputError: If ``prompt`` is empty or whitespace-only.
        DecartSDKError: If no acknowledgment arrives within 15 seconds, or
            the acknowledgment reports ``success`` as false.
    """
    if not prompt or not prompt.strip():
        raise InvalidInputError("Prompt cannot be empty")

    # Register the waiter before sending so an acknowledgment that arrives
    # right after the send always finds a registered entry.
    event, result = self._manager.register_prompt_wait(prompt)

    try:
        # The prompt is sent exactly once, and only after the waiter exists.
        await self._manager.send_message(
            PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enrich)
        )

        try:
            await asyncio.wait_for(event.wait(), timeout=15.0)
        except asyncio.TimeoutError as exc:
            # Chain the cause so the traceback shows this came from the wait.
            raise DecartSDKError("Prompt acknowledgment timed out") from exc

        if not result["success"]:
            raise DecartSDKError(result["error"] or "Prompt failed")
    finally:
        # Always drop the waiter, whether we succeeded, failed or timed out.
        self._manager.unregister_prompt_wait(prompt)

def is_connected(self) -> bool:
return self._manager.is_connected()
Expand Down
14 changes: 12 additions & 2 deletions decart/realtime/messages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Literal, Union, Annotated
from typing import Literal, Optional, Union, Annotated
from pydantic import BaseModel, Field, TypeAdapter

try:
Expand Down Expand Up @@ -42,9 +42,18 @@ class SessionIdMessage(BaseModel):
server_ip: str


class PromptAckMessage(BaseModel):
    """Server acknowledgment of a prompt update.

    Selected from incoming payloads whose ``type`` is ``"prompt_ack"``.
    """

    # Discriminator value identifying this message kind.
    type: Literal["prompt_ack"]
    # Prompt text this acknowledgment refers to.
    prompt: str
    # Outcome flag reported by the server for the prompt update.
    success: bool
    # Optional error text; None when the server supplies none.
    error: Optional[str] = None


# Discriminated union for incoming messages: the concrete model is selected by
# each payload's "type" field. Only one Union may appear as the first argument
# to Annotated; anything after it is treated as metadata, not as a type.
IncomingMessage = Annotated[
    Union[AnswerMessage, IceCandidateMessage, SessionIdMessage, PromptAckMessage],
    Field(discriminator="type"),
]

Expand All @@ -67,6 +76,7 @@ class PromptMessage(BaseModel):

type: Literal["prompt"]
prompt: str
enhance_prompt: bool = True


# Outgoing message union (no discriminator needed - we know what we're sending)
Expand Down
22 changes: 21 additions & 1 deletion decart/realtime/webrtc_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
OfferMessage,
IceCandidateMessage,
IceCandidatePayload,
PromptAckMessage,
OutgoingMessage,
)
from .types import ConnectionState
Expand All @@ -36,7 +37,6 @@ def __init__(
on_error: Optional[Callable[[Exception], None]] = None,
customize_offer: Optional[Callable] = None,
):

self._pc: Optional[RTCPeerConnection] = None
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
self._session: Optional[aiohttp.ClientSession] = None
Expand All @@ -47,6 +47,7 @@ def __init__(
self._customize_offer = customize_offer
self._ws_task: Optional[asyncio.Task] = None
self._ice_candidates_queue: list[RTCIceCandidate] = []
self._pending_prompts: dict[str, tuple[asyncio.Event, dict]] = {}

async def connect(
self,
Expand Down Expand Up @@ -176,6 +177,8 @@ async def _handle_message(self, data: dict) -> None:
await self._handle_ice_candidate(message.candidate)
elif message.type == "session_id":
logger.debug(f"Session ID: {message.session_id}")
elif message.type == "prompt_ack":
self._handle_prompt_ack(message)

async def _handle_answer(self, sdp: str) -> None:
logger.debug("Received answer from server")
Expand Down Expand Up @@ -207,6 +210,23 @@ async def _handle_ice_candidate(self, candidate_data: IceCandidatePayload) -> No
logger.debug("Queuing ICE candidate (no remote description yet)")
self._ice_candidates_queue.append(candidate)

def _handle_prompt_ack(self, message: PromptAckMessage) -> None:
    """Deliver a ``prompt_ack`` to the waiter registered for its prompt.

    Copies the message's ``success`` and ``error`` into the waiter's result
    dict and then sets its event. An acknowledgment whose prompt has no
    registered waiter is logged at debug level and otherwise ignored.

    Args:
        message: Parsed acknowledgment received from the server.
    """
    # Lazy %-style arguments: the string is only built when DEBUG is enabled.
    logger.debug(
        "Received prompt_ack for: %s, success: %s", message.prompt, message.success
    )

    # Single lookup instead of a membership test followed by indexing.
    pending = self._pending_prompts.get(message.prompt)
    if pending is None:
        logger.debug("No pending waiter for prompt_ack: %s", message.prompt)
        return

    event, result = pending
    # Fill in the result before setting the event so a woken waiter reads
    # the final values.
    result["success"] = message.success
    result["error"] = message.error
    event.set()

def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]:
    """Register a waiter for the acknowledgment of ``prompt``.

    If a waiter for the same prompt is already registered and its event has
    not been set yet, that same ``(event, result)`` pair is returned instead
    of replacing it. Replacing it would leave the earlier caller waiting on
    an event that is no longer stored and can therefore never be set.

    Args:
        prompt: Prompt text used as the lookup key for the acknowledgment.

    Returns:
        A ``(event, result)`` pair. ``result`` starts as
        ``{"success": False, "error": None}``.
    """
    existing = self._pending_prompts.get(prompt)
    if existing is not None and not existing[0].is_set():
        # Same prompt still in flight: share the waiter rather than orphan it.
        return existing

    event = asyncio.Event()
    result: dict = {"success": False, "error": None}
    self._pending_prompts[prompt] = (event, result)
    return event, result

def unregister_prompt_wait(self, prompt: str) -> None:
    """Remove the waiter stored for ``prompt``; do nothing if there is none."""
    pending = self._pending_prompts
    if prompt in pending:
        del pending[prompt]

async def _send_message(self, message: OutgoingMessage) -> None:
if not self._ws or self._ws.closed:
raise RuntimeError("WebSocket not connected")
Expand Down
7 changes: 7 additions & 0 deletions decart/realtime/webrtc_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import Optional, Callable
from dataclasses import dataclass
Expand Down Expand Up @@ -84,3 +85,9 @@ def is_connected(self) -> bool:

def get_connection_state(self) -> ConnectionState:
    """Return the ``state`` attribute of the underlying connection."""
    connection = self._connection
    return connection.state

def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]:
    """Forward to the connection's ``register_prompt_wait`` and return its result."""
    waiter = self._connection.register_prompt_wait(prompt)
    return waiter

def unregister_prompt_wait(self, prompt: str) -> None:
    """Forward to the connection's ``unregister_prompt_wait``."""
    connection = self._connection
    connection.unregister_prompt_wait(prompt)
1 change: 0 additions & 1 deletion examples/lipsync_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ async def process_lipsync(video_path: str, audio_path: str, output_path: str):
)
for i in range(frame_count):
try:

video_frame, audio_frame = await client.get_synced_output(timeout=1.0)
bgr_frame = cv2.cvtColor(video_frame, cv2.COLOR_RGB2BGR)
out.write(bgr_frame)
Expand Down
6 changes: 5 additions & 1 deletion examples/realtime_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ def on_error(error):
await asyncio.sleep(5)

print("\n🎨 Changing style to 'Cyberpunk city'...")
await realtime_client.set_prompt("Cyberpunk city")
try:
await realtime_client.set_prompt("Cyberpunk city")
print("✓ Prompt set successfully")
except Exception as e:
print(f"⚠️ Failed to set prompt: {e}")

await asyncio.sleep(5)

Expand Down
130 changes: 128 additions & 2 deletions tests/test_realtime_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,23 @@ def test_realtime_models_available():
@pytest.mark.asyncio
async def test_realtime_client_creation_with_mock():
"""Test client creation with mocked WebRTC"""
import asyncio

client = DecartClient(api_key="test-key")

with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
mock_manager = AsyncMock()
mock_manager.connect = AsyncMock(return_value=True)
mock_manager.is_connected = MagicMock(return_value=True)
mock_manager.get_connection_state = MagicMock(return_value="connected")
mock_manager.send_message = AsyncMock()

prompt_event = asyncio.Event()
prompt_result = {"success": True, "error": None}
prompt_event.set()

mock_manager.register_prompt_wait = MagicMock(return_value=(prompt_event, prompt_result))
mock_manager.unregister_prompt_wait = MagicMock()
mock_manager_class.return_value = mock_manager

mock_track = MagicMock()
Expand All @@ -76,13 +86,24 @@ async def test_realtime_client_creation_with_mock():

@pytest.mark.asyncio
async def test_realtime_set_prompt_with_mock():
"""Test set_prompt with mocked WebRTC"""
"""Test set_prompt with mocked WebRTC and prompt_ack"""
import asyncio

client = DecartClient(api_key="test-key")

with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
mock_manager = AsyncMock()
mock_manager.connect = AsyncMock(return_value=True)
mock_manager.send_message = AsyncMock()

prompt_event = asyncio.Event()
prompt_result = {"success": True, "error": None}

def register_prompt_wait(prompt):
return prompt_event, prompt_result

mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait)
mock_manager.unregister_prompt_wait = MagicMock()
mock_manager_class.return_value = mock_manager

mock_track = MagicMock()
Expand All @@ -99,12 +120,19 @@ async def test_realtime_set_prompt_with_mock():
),
)

async def set_event():
await asyncio.sleep(0.01)
prompt_event.set()

asyncio.create_task(set_event())
await realtime_client.set_prompt("New prompt")

mock_manager.send_message.assert_called_once()
mock_manager.send_message.assert_called()
call_args = mock_manager.send_message.call_args[0][0]
assert call_args.type == "prompt"
assert call_args.prompt == "New prompt"
assert call_args.enhance_prompt is True
mock_manager.unregister_prompt_wait.assert_called_with("New prompt")


@pytest.mark.asyncio
Expand Down Expand Up @@ -152,3 +180,101 @@ def on_error(error):
realtime_client._emit_error(test_error)
assert len(errors) == 1
assert errors[0].message == "Test error"


@pytest.mark.asyncio
async def test_realtime_set_prompt_timeout():
    """set_prompt raises DecartSDKError when the acknowledgment wait times out."""
    import asyncio

    from decart.errors import DecartSDKError
    from decart.realtime.types import RealtimeConnectOptions

    sdk_client = DecartClient(api_key="test-key")

    with patch("decart.realtime.client.WebRTCManager") as manager_cls:
        # The event is never set: no acknowledgment is delivered in this test.
        never_acked = asyncio.Event()
        pending_result = {"success": False, "error": None}

        manager = AsyncMock()
        manager.connect = AsyncMock(return_value=True)
        manager.send_message = AsyncMock()
        manager.register_prompt_wait = MagicMock(
            side_effect=lambda prompt: (never_acked, pending_result)
        )
        manager.unregister_prompt_wait = MagicMock()
        manager_cls.return_value = manager

        connect_options = RealtimeConnectOptions(
            model=models.realtime("mirage"),
            on_remote_stream=lambda t: None,
        )
        realtime_client = await RealtimeClient.connect(
            base_url=sdk_client.base_url,
            api_key=sdk_client.api_key,
            local_track=MagicMock(),
            options=connect_options,
        )

        # Mock asyncio.wait_for to immediately raise TimeoutError
        with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError), pytest.raises(
            DecartSDKError
        ) as raised:
            await realtime_client.set_prompt("New prompt")

        assert "timed out" in str(raised.value)
        # The waiter must be cleaned up even though the call failed.
        manager.unregister_prompt_wait.assert_called_with("New prompt")


@pytest.mark.asyncio
async def test_realtime_set_prompt_server_error():
    """set_prompt raises DecartSDKError carrying the server's error text."""
    import asyncio

    from decart.errors import DecartSDKError
    from decart.realtime.types import RealtimeConnectOptions

    client = DecartClient(api_key="test-key")

    with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
        mock_manager = AsyncMock()
        mock_manager.connect = AsyncMock(return_value=True)
        mock_manager.send_message = AsyncMock()

        # Result as it looks after a failed acknowledgment was recorded.
        prompt_event = asyncio.Event()
        prompt_result = {"success": False, "error": "Server rejected prompt"}

        def register_prompt_wait(prompt):
            return prompt_event, prompt_result

        mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait)
        mock_manager.unregister_prompt_wait = MagicMock()
        mock_manager_class.return_value = mock_manager

        mock_track = MagicMock()

        realtime_client = await RealtimeClient.connect(
            base_url=client.base_url,
            api_key=client.api_key,
            local_track=mock_track,
            options=RealtimeConnectOptions(
                model=models.realtime("mirage"),
                on_remote_stream=lambda t: None,
            ),
        )

        async def set_event():
            # Simulate the acknowledgment arriving shortly after the send.
            await asyncio.sleep(0.01)
            prompt_event.set()

        # Keep a reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-flight.
        ack_task = asyncio.create_task(set_event())

        with pytest.raises(DecartSDKError) as exc_info:
            await realtime_client.set_prompt("New prompt")

        # The task set the event, so it is finished; awaiting it surfaces any
        # unexpected exception and leaves no pending task behind.
        await ack_task

        assert "Server rejected prompt" in str(exc_info.value)
        mock_manager.unregister_prompt_wait.assert_called_with("New prompt")
4 changes: 2 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading