Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/mock_rt_server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import logging
import time
Expand Down Expand Up @@ -152,6 +153,19 @@ def dummy_add_transcript():
],
}

async def close_connection(ws):
"""Closes the connection after a delay."""
await asyncio.sleep(5)
await ws.send(json.dumps(
{
"message": "Error",
"format": "2.1",
"metadata": {"start_time": 0.0, "end_time": 1.0},
"type": "idle_timeout"
}
))
await ws.close(code=1008)
logging.info("Connection closed after 5 seconds")

async def mock_server_handler(websocket, logbook):
mock_server_handler.next_audio_seq_no = 1
Expand Down Expand Up @@ -214,6 +228,7 @@ def get_responses(message, is_binary=False):
},
}
)
asyncio.create_task(close_connection(websocket))
elif msg_name == "EndOfStream":
responses.append({"message": "EndOfTranscript"})
elif msg_name == "SetRecognitionConfig":
Expand Down
51 changes: 51 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,57 @@ def test_batch_client_api_key_constructor(mocker):
assert batch_client.connection_settings.generate_temp_token


class BlockingSyncStream:
def __init__(self):
self.content = [b"\x00", b"\x01", b"\x02", b"\x03", b"\x04"]

def read(self, _):
while True:
if self.content:
# Block forever if the end of the array is reached
return self.content.pop(0)


class AsyncStream(BlockingSyncStream):
def __init__(self):
super().__init__()

async def read(self, _):
while True:
if self.content:
return self.content.pop(0)
await asyncio.sleep(0.01) # Non-blocking wait


@pytest.mark.asyncio
async def test_blocking_stream_close(mock_server):
ws_client, transcription_config, audio_settings = default_ws_client_setup(
mock_server.url
)
stream = BlockingSyncStream()
try:
await asyncio.wait_for(
ws_client.run(stream, transcription_config, audio_settings), timeout=10
)
except asyncio.TimeoutError:
assert False, "The command failed to finish within the timeout period"


@pytest.mark.asyncio
@pytest.mark.xfail(reason="Expected to fail due to malformed input")
async def test_async_stream_close(mock_server):
ws_client, transcription_config, audio_settings = default_ws_client_setup(
mock_server.url
)
stream = AsyncStream()
try:
await asyncio.wait_for(
ws_client.run(stream, transcription_config, audio_settings), timeout=10
)
except asyncio.TimeoutError:
assert False, "The command failed to finish within the timeout period"


def deepcopy_state(obj):
"""
Return a deepcopy of the __dict__ (or state) of an object but ignore
Expand Down