forked from GetStream/Vision-Agents
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconftest.py
More file actions
372 lines (282 loc) · 11 KB
/
conftest.py
File metadata and controls
372 lines (282 loc) · 11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
"""
Root conftest.py - Shared fixtures for all tests.
Pytest automatically discovers fixtures defined here and makes them
available to all tests in the project, including plugin tests.
"""
import asyncio
import os
from typing import Iterator
import numpy as np
import pytest
from blockbuster import BlockBuster, blockbuster_ctx
from dotenv import load_dotenv
from torchvision.io.video import av
from getstream.video.rtc.track_util import PcmData, AudioFormat
from vision_agents.core.stt.events import STTTranscriptEvent, STTErrorEvent, STTPartialTranscriptEvent
load_dotenv()
def skip_blockbuster(func_or_class):
"""Decorator to skip blockbuster checks for a test function or class.
Use this decorator when testing code that makes unavoidable blocking calls
(e.g., third-party SDKs like boto3, fish-audio-sdk).
Examples:
@skip_blockbuster
async def test_aws_function():
# boto3 makes blocking calls we can't fix
pass
@skip_blockbuster
class TestAWSIntegration:
# All tests in this class skip blockbuster
pass
"""
return pytest.mark.skip_blockbuster(func_or_class)
@pytest.fixture(autouse=True)
def blockbuster(request) -> Iterator[BlockBuster | None]:
"""Blockbuster fixture that detects blocking calls in async code.
Can be disabled for specific tests using the @skip_blockbuster decorator.
"""
# Check if test is marked to skip blockbuster
if request.node.get_closest_marker("skip_blockbuster"):
yield None
else:
with blockbuster_ctx() as bb:
yield bb
class STTSession:
"""Helper class for testing STT implementations.
Automatically subscribes to transcript and error events,
collects them, and provides a convenient wait method.
"""
def __init__(self, stt):
"""Initialize STT session with an STT object.
Args:
stt: STT implementation to monitor
"""
self.stt = stt
self.transcripts = []
self.partial_transcripts = []
self.errors = []
self._event = asyncio.Event()
# Subscribe to events
@stt.events.subscribe
async def on_transcript(event: STTTranscriptEvent):
self.transcripts.append(event)
self._event.set()
@stt.events.subscribe
async def on_partial_transcript(event: STTPartialTranscriptEvent):
self.partial_transcripts.append(event)
@stt.events.subscribe
async def on_error(event: STTErrorEvent):
self.errors.append(event.error)
self._event.set()
self._on_transcript = on_transcript
self._on_error = on_error
async def wait_for_result(self, timeout: float = 30.0):
"""Wait for either a transcript or error event.
Args:
timeout: Maximum time to wait in seconds
Raises:
asyncio.TimeoutError: If no result received within timeout
"""
# Allow event subscriptions to be processed
await asyncio.sleep(0.01)
# Wait for an event
await asyncio.wait_for(self._event.wait(), timeout=timeout)
def get_full_transcript(self) -> str:
"""Get full transcription text from all transcript events.
Returns:
Combined text from all transcripts
"""
return " ".join(t.text for t in self.transcripts)
def get_assets_dir():
"""Get the test assets directory path."""
return os.path.join(os.path.dirname(__file__), "tests", "test_assets")
@pytest.fixture(scope="session")
def assets_dir():
"""Fixture providing the test assets directory path."""
return get_assets_dir()
@pytest.fixture
def mia_audio_16khz():
"""Load mia.mp3 and convert to 16kHz PCM data."""
audio_file_path = os.path.join(get_assets_dir(), "mia.mp3")
# Load audio file using PyAV
container = av.open(audio_file_path)
audio_stream = container.streams.audio[0]
original_sample_rate = audio_stream.sample_rate
target_rate = 16000
# Create resampler if needed
resampler = None
if original_sample_rate != target_rate:
resampler = av.AudioResampler(format="s16", layout="mono", rate=target_rate)
# Read all audio frames
samples = []
for frame in container.decode(audio_stream):
# Resample if needed
if resampler:
frame = resampler.resample(frame)[0]
# Convert to numpy array
frame_array = frame.to_ndarray()
if len(frame_array.shape) > 1:
# Convert stereo to mono
frame_array = np.mean(frame_array, axis=0)
samples.append(frame_array)
# Concatenate all samples
samples = np.concatenate(samples)
# Convert to int16
samples = samples.astype(np.int16)
container.close()
# Create PCM data
pcm = PcmData(samples=samples, sample_rate=target_rate, format=AudioFormat.S16)
return pcm
@pytest.fixture
def mia_audio_48khz():
"""Load mia.mp3 and convert to 48kHz PCM data."""
audio_file_path = os.path.join(get_assets_dir(), "mia.mp3")
# Load audio file using PyAV
container = av.open(audio_file_path)
audio_stream = container.streams.audio[0]
original_sample_rate = audio_stream.sample_rate
target_rate = 48000
# Create resampler if needed
resampler = None
if original_sample_rate != target_rate:
resampler = av.AudioResampler(format="s16", layout="mono", rate=target_rate)
# Read all audio frames
samples = []
for frame in container.decode(audio_stream):
# Resample if needed
if resampler:
frame = resampler.resample(frame)[0]
# Convert to numpy array
frame_array = frame.to_ndarray()
if len(frame_array.shape) > 1:
# Convert stereo to mono
frame_array = np.mean(frame_array, axis=0)
samples.append(frame_array)
# Concatenate all samples
samples = np.concatenate(samples)
# Convert to int16
samples = samples.astype(np.int16)
container.close()
# Create PCM data
pcm = PcmData(samples=samples, sample_rate=target_rate, format=AudioFormat.S16)
return pcm
@pytest.fixture
def silence_2s_48khz():
"""Generate 2 seconds of silence at 48kHz PCM data."""
sample_rate = 48000
duration_seconds = 2.0
# Calculate number of samples for 2 seconds
num_samples = int(sample_rate * duration_seconds)
# Create silence (zeros) as int16
samples = np.zeros(num_samples, dtype=np.int16)
# Create PCM data
pcm = PcmData(samples=samples, sample_rate=sample_rate, format=AudioFormat.S16)
return pcm
@pytest.fixture
def mia_audio_48khz_chunked():
"""Load mia.mp3 and yield 48kHz PCM data in 20ms chunks."""
audio_file_path = os.path.join(get_assets_dir(), "mia.mp3")
# Load audio file using PyAV
container = av.open(audio_file_path)
audio_stream = container.streams.audio[0]
original_sample_rate = audio_stream.sample_rate
target_rate = 48000
# Create resampler if needed
resampler = None
if original_sample_rate != target_rate:
resampler = av.AudioResampler(format="s16", layout="mono", rate=target_rate)
# Read all audio frames
samples = []
for frame in container.decode(audio_stream):
# Resample if needed
if resampler:
frame = resampler.resample(frame)[0]
# Convert to numpy array
frame_array = frame.to_ndarray()
if len(frame_array.shape) > 1:
# Convert stereo to mono
frame_array = np.mean(frame_array, axis=0)
samples.append(frame_array)
# Concatenate all samples
samples = np.concatenate(samples)
# Convert to int16
samples = samples.astype(np.int16)
container.close()
# Calculate chunk size for 20ms at 48kHz
chunk_size = int(target_rate * 0.020) # 960 samples per 20ms
# Yield chunks of audio
chunks = []
for i in range(0, len(samples), chunk_size):
chunk_samples = samples[i : i + chunk_size]
# Create PCM data for this chunk
pcm_chunk = PcmData(
samples=chunk_samples, sample_rate=target_rate, format=AudioFormat.S16
)
chunks.append(pcm_chunk)
return chunks
@pytest.fixture
def golf_swing_image():
"""Load golf_swing.png image and return as bytes."""
image_file_path = os.path.join(get_assets_dir(), "golf_swing.png")
with open(image_file_path, "rb") as f:
image_bytes = f.read()
return image_bytes
@pytest.fixture
async def bunny_video_track():
"""Create RealVideoTrack from video file."""
from aiortc import VideoStreamTrack
video_file_path = os.path.join(get_assets_dir(), "bunny_3s.mp4")
class RealVideoTrack(VideoStreamTrack):
def __init__(self, video_path, max_frames=None):
super().__init__()
self.container = av.open(video_path)
self.video_stream = self.container.streams.video[0]
self.frame_count = 0
self.max_frames = max_frames
self.frame_duration = 1.0 / 15.0 # 15 fps
async def recv(self):
if self.max_frames is not None and self.frame_count >= self.max_frames:
raise asyncio.CancelledError("No more frames")
try:
for frame in self.container.decode(self.video_stream):
if frame is None:
raise asyncio.CancelledError("End of video stream")
self.frame_count += 1
frame = frame.to_rgb()
await asyncio.sleep(self.frame_duration)
return frame
raise asyncio.CancelledError("End of video stream")
except asyncio.CancelledError:
raise
except Exception as e:
if "End of file" in str(e) or "avcodec_send_packet" in str(e):
raise asyncio.CancelledError("End of video stream")
else:
print(f"Error reading video frame: {e}")
raise asyncio.CancelledError("Video read error")
track = RealVideoTrack(video_file_path, max_frames=None)
try:
yield track
finally:
track.container.close()
@pytest.fixture
async def audio_track_48khz():
"""Create audio track that produces 48kHz audio frames."""
from getstream.video.rtc.audio_track import AudioStreamTrack
audio_file_path = os.path.join(get_assets_dir(), "formant_speech_48k.wav")
class TestAudioTrack(AudioStreamTrack):
def __init__(self, audio_path):
super().__init__()
self.container = av.open(audio_path)
self.audio_stream = self.container.streams.audio[0]
self.decoder = self.container.decode(self.audio_stream)
async def recv(self):
try:
frame = next(self.decoder)
return frame
except StopIteration:
raise asyncio.CancelledError("End of audio stream")
track = TestAudioTrack(audio_file_path)
try:
yield track
finally:
track.container.close()