diff --git a/.hathora_build/app/serve_asr.py b/.hathora_build/app/serve_asr.py index 4b3daff31e8a..d0cf9786d380 100644 --- a/.hathora_build/app/serve_asr.py +++ b/.hathora_build/app/serve_asr.py @@ -11,10 +11,13 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "0" os.environ["TQDM_DISABLE"] = "0" -from fastapi import FastAPI, File, Query, UploadFile +from fastapi import FastAPI, File, Query, UploadFile, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware +from starlette.websockets import WebSocketState import nemo.collections.asr as nemo_asr +import numpy as np import torch +import asyncio logging.basicConfig( level=logging.INFO, @@ -190,24 +193,24 @@ def extract_audio_segment(input_path: str, start_time: Optional[float], end_time """Extract audio segment using ffmpeg if start_time or end_time is specified.""" if start_time is None and end_time is None: return input_path - + output_fd, output_path = tempfile.mkstemp(suffix=".wav") os.close(output_fd) - + cmd = ["ffmpeg", "-y", "-i", input_path] - + if start_time is not None: cmd.extend(["-ss", str(start_time)]) - + if end_time is not None: if start_time is not None: duration = end_time - start_time cmd.extend(["-t", str(duration)]) else: cmd.extend(["-to", str(end_time)]) - + cmd.extend(["-ac", "1", "-ar", "16000", output_path]) - + try: subprocess.run(cmd, check=True, capture_output=True) return output_path @@ -216,6 +219,112 @@ def extract_audio_segment(input_path: str, start_time: Optional[float], end_time raise RuntimeError(f"Failed to extract audio segment: {e.stderr.decode()}") +class StreamingASRProcessor: + """ + Handles real-time streaming ASR with buffered inference. + Uses NeMo's FrameBatchASR for efficient streaming with the parakeet-realtime-eou model. + """ + + def __init__(self, model, chunk_len_in_secs: float = 0.16, buffer_len_in_secs: float = 1.6): + """ + Initialize streaming processor. + + Args: + model: NeMo ASR model + chunk_len_in_secs: Length of each audio chunk (0.16s = 160ms for low latency) + buffer_len_in_secs: Total buffer length including context (1.6s recommended) + """ + self.model = model + self.chunk_len = chunk_len_in_secs + self.buffer_len = buffer_len_in_secs + self.sample_rate = 16000 + + # Try to import FrameBatchASR + try: + from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR + + # Initialize frame-based ASR for streaming + self.frame_asr = FrameBatchASR( + asr_model=model, + frame_len=chunk_len_in_secs, + total_buffer_in_secs=buffer_len_in_secs, + batch_size=1 # Process one stream at a time + ) + self.use_frame_batch = True + logger.info(f"StreamingASRProcessor initialized with FrameBatchASR: chunk={chunk_len_in_secs}s, buffer={buffer_len_in_secs}s") + except ImportError: + # Fallback to basic transcription if FrameBatchASR not available + self.use_frame_batch = False + logger.warning("FrameBatchASR not available, using fallback transcription method") + + def reset(self): + """Reset the streaming buffer for a new session.""" + if self.use_frame_batch: + self.frame_asr.reset() + + def transcribe_chunk(self, audio_data: np.ndarray) -> Optional[str]: + """ + Transcribe a chunk of audio data. 
+ + Args: + audio_data: Float32 numpy array of audio samples (16kHz, mono) + + Returns: + Transcription text if available, None otherwise + """ + try: + # Ensure audio is float32 normalized to [-1, 1] + if audio_data.dtype != np.float32: + if audio_data.dtype == np.int16: + audio_data = audio_data.astype(np.float32) / 32768.0 + else: + audio_data = audio_data.astype(np.float32) + + if self.use_frame_batch: + # Use FrameBatchASR for efficient streaming + hypotheses = self.frame_asr.transcribe_signal(audio_data) + if hypotheses and len(hypotheses) > 0: + text = hypotheses[0] + if text and text.strip(): + return text + else: + # Fallback: Create temp WAV file and transcribe + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".wav") + os.close(tmp_fd) + + try: + # Convert to int16 for WAV + audio_int16 = (audio_data * 32768.0).astype(np.int16) + + with wave.open(tmp_path, 'wb') as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(self.sample_rate) + wav_file.writeframes(audio_int16.tobytes()) + + # Transcribe + results = self.model.transcribe( + [tmp_path], + batch_size=1, + return_hypotheses=False, + verbose=False, + num_workers=0, + ) + texts = extract_texts(results) + text = texts[0] if texts else "" + if text and text.strip(): + return text + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + return None + + except Exception as e: + logger.error(f"Transcription error: {e}", exc_info=True) + return None + + @app.post("/v1/transcribe") async def transcribe( file: UploadFile = File(...), @@ -310,3 +419,108 @@ async def transcribe( pass +@app.websocket("/v1/transcribe/stream") +async def transcribe_stream(websocket: WebSocket): + """ + WebSocket endpoint for real-time streaming transcription with EOU detection. + + Protocol: + - Client sends: Binary frames containing raw PCM Int16 audio data (16kHz, mono, little-endian) + - Server responds: JSON messages with transcription results + {"type": "transcription", "text": "hello world", "is_eou": false} + {"type": "transcription", "text": "how are you", "is_eou": true} + {"type": "error", "message": "error details"} + + The parakeet-realtime-eou model emits tokens to signal end-of-utterance, + enabling ultra-low latency (80-160ms) conversational experiences. 
+ """ + await websocket.accept() + logger.info("WebSocket connection established for streaming transcription") + + model = model_manager.get_model() + + # Initialize streaming processor with low-latency settings + # chunk_len=0.16s gives ~160ms latency (matching model's design) + processor = StreamingASRProcessor( + model=model, + chunk_len_in_secs=0.16, # 160ms chunks for low latency + buffer_len_in_secs=1.6 # 1.6s buffer for context + ) + + # Audio configuration + sample_rate = 16000 + bytes_per_sample = 2 # Int16 + + # Chunk size: 160ms = 2560 samples = 5120 bytes + samples_per_chunk = int(sample_rate * 0.16) + bytes_per_chunk = samples_per_chunk * bytes_per_sample + + # Buffer for accumulating incoming audio + audio_buffer = bytearray() + + # Track last transcription to avoid duplicates + last_transcription = "" + + try: + while True: + # Receive audio data from client + data = await websocket.receive_bytes() + audio_buffer.extend(data) + + # Process when we have enough audio for a chunk + while len(audio_buffer) >= bytes_per_chunk: + # Extract chunk + chunk_bytes = bytes(audio_buffer[:bytes_per_chunk]) + audio_buffer = audio_buffer[bytes_per_chunk:] + + # Convert bytes to numpy array + audio_int16 = np.frombuffer(chunk_bytes, dtype=np.int16) + audio_float32 = audio_int16.astype(np.float32) / 32768.0 + + # Transcribe the chunk + transcript = processor.transcribe_chunk(audio_float32) + + if transcript and transcript != last_transcription: + # Check for EOU token + has_eou = EOU_TOKEN in transcript + + # Send transcription result + response = { + "type": "transcription", + "text": transcript, + "is_eou": has_eou + } + await websocket.send_json(response) + logger.info(f"Transcription: '{transcript}' (EOU: {has_eou})") + + last_transcription = transcript + + # Reset processor after EOU for next utterance + if has_eou: + processor.reset() + last_transcription = "" + + # Small delay to prevent busy loop + await asyncio.sleep(0.01) + + except WebSocketDisconnect: + logger.info("WebSocket client disconnected") + except Exception as e: + logger.error(f"WebSocket error: {str(e)}", exc_info=True) + try: + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.send_json({ + "type": "error", + "message": str(e) + }) + except Exception: + pass + finally: + try: + if websocket.client_state == WebSocketState.CONNECTED: + await websocket.close() + except Exception: + pass + logger.info("WebSocket connection closed") + +
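
A minimal client sketch for exercising the new /v1/transcribe/stream endpoint, shown for reference and independent of the server patch above. It assumes the third-party `websockets` package, a server reachable at ws://localhost:8000, and a 16 kHz mono 16-bit PCM WAV file named input.wav; the host, port, and file name are illustrative placeholders. The 5120-byte chunk size mirrors the 160 ms frames the handler accumulates.

    # stream_client.py -- usage sketch only; endpoint URL, file name, and pacing
    # are assumptions, not part of the patch above.
    import asyncio
    import json
    import wave

    import websockets  # third-party: pip install websockets

    SERVER_URL = "ws://localhost:8000/v1/transcribe/stream"  # assumed host/port
    CHUNK_BYTES = 5120  # 160 ms of 16 kHz mono Int16 PCM, matching the server


    async def stream_file(path: str) -> None:
        async with websockets.connect(SERVER_URL) as ws:
            # Read the whole file up front; a live client would capture from a mic.
            with wave.open(path, "rb") as wav:
                assert (wav.getframerate(), wav.getnchannels(), wav.getsampwidth()) == (16000, 1, 2)
                pcm = wav.readframes(wav.getnframes())

            async def receiver() -> None:
                # Print transcription results (and EOU flags) as the server emits them.
                async for message in ws:
                    result = json.loads(message)
                    if result.get("type") == "transcription":
                        suffix = " <EOU>" if result.get("is_eou") else ""
                        print(result["text"] + suffix)

            recv_task = asyncio.create_task(receiver())
            for offset in range(0, len(pcm), CHUNK_BYTES):
                await ws.send(pcm[offset:offset + CHUNK_BYTES])  # raw PCM binary frame
                await asyncio.sleep(0.16)  # pace roughly in real time
            await asyncio.sleep(2.0)  # give trailing results time to arrive
            recv_task.cancel()
            try:
                await recv_task
            except asyncio.CancelledError:
                pass


    if __name__ == "__main__":
        asyncio.run(stream_file("input.wav"))

The pacing sleep only mimics a live microphone; the server buffers whatever bytes arrive and transcribes them in 160 ms chunks, so a client may also send audio as fast as it is produced.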