diff --git a/tests/mock_rt_server.py b/tests/mock_rt_server.py index 387c452..a57773a 100644 --- a/tests/mock_rt_server.py +++ b/tests/mock_rt_server.py @@ -1,3 +1,4 @@ +import asyncio import json import logging import time @@ -152,6 +153,19 @@ def dummy_add_transcript(): ], } +async def close_connection(ws): + """Closes the connection after a delay.""" + await asyncio.sleep(5) + await ws.send(json.dumps( + { + "message": "Error", + "format": "2.1", + "metadata": {"start_time": 0.0, "end_time": 1.0}, + "type": "idle_timeout" + } + )) + await ws.close(code=1008) + logging.info("Connection closed after 5 seconds") async def mock_server_handler(websocket, logbook): mock_server_handler.next_audio_seq_no = 1 @@ -214,6 +228,7 @@ def get_responses(message, is_binary=False): }, } ) + asyncio.create_task(close_connection(websocket)) elif msg_name == "EndOfStream": responses.append({"message": "EndOfTranscript"}) elif msg_name == "SetRecognitionConfig": diff --git a/tests/test_client.py b/tests/test_client.py index 29361c7..59f0b05 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -716,6 +716,57 @@ def test_batch_client_api_key_constructor(mocker): assert batch_client.connection_settings.generate_temp_token +class BlockingSyncStream: + def __init__(self): + self.content = [b"\x00", b"\x01", b"\x02", b"\x03", b"\x04"] + + def read(self, _): + while True: + if self.content: + # Block forever if the end of the array is reached + return self.content.pop(0) + + +class AsyncStream(BlockingSyncStream): + def __init__(self): + super().__init__() + + async def read(self, _): + while True: + if self.content: + return self.content.pop(0) + await asyncio.sleep(0.01) # Non-blocking wait + + +@pytest.mark.asyncio +async def test_blocking_stream_close(mock_server): + ws_client, transcription_config, audio_settings = default_ws_client_setup( + mock_server.url + ) + stream = BlockingSyncStream() + try: + await asyncio.wait_for( + ws_client.run(stream, transcription_config, audio_settings), timeout=10 + ) + except asyncio.TimeoutError: + assert False, "The command failed to finish within the timeout period" + + +@pytest.mark.asyncio +@pytest.mark.xfail(reason="Expected to fail due to malformed input") +async def test_async_stream_close(mock_server): + ws_client, transcription_config, audio_settings = default_ws_client_setup( + mock_server.url + ) + stream = AsyncStream() + try: + await asyncio.wait_for( + ws_client.run(stream, transcription_config, audio_settings), timeout=10 + ) + except asyncio.TimeoutError: + assert False, "The command failed to finish within the timeout period" + + def deepcopy_state(obj): """ Return a deepcopy of the __dict__ (or state) of an object but ignore