228 changes: 221 additions & 7 deletions .hathora_build/app/serve_asr.py
@@ -11,10 +11,13 @@
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "0"
os.environ["TQDM_DISABLE"] = "0"

from fastapi import FastAPI, File, Query, UploadFile
from fastapi import FastAPI, File, Query, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from starlette.websockets import WebSocketState
import nemo.collections.asr as nemo_asr
import numpy as np
import torch
import asyncio

logging.basicConfig(
level=logging.INFO,
@@ -190,24 +193,24 @@ def extract_audio_segment(input_path: str, start_time: Optional[float], end_time
"""Extract audio segment using ffmpeg if start_time or end_time is specified."""
if start_time is None and end_time is None:
return input_path

output_fd, output_path = tempfile.mkstemp(suffix=".wav")
os.close(output_fd)

cmd = ["ffmpeg", "-y", "-i", input_path]

if start_time is not None:
cmd.extend(["-ss", str(start_time)])

if end_time is not None:
if start_time is not None:
duration = end_time - start_time
cmd.extend(["-t", str(duration)])
else:
cmd.extend(["-to", str(end_time)])

cmd.extend(["-ac", "1", "-ar", "16000", output_path])

try:
subprocess.run(cmd, check=True, capture_output=True)
return output_path
@@ -216,6 +219,112 @@ def extract_audio_segment(input_path: str, start_time: Optional[float], end_time
raise RuntimeError(f"Failed to extract audio segment: {e.stderr.decode()}")

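For reference, a hypothetical call with start_time=5.0 and end_time=12.5 makes this helper build a command equivalent to the list below (file names are illustrative; the real output path is a generated temp file):

    cmd = ["ffmpeg", "-y", "-i", "input.wav", "-ss", "5.0", "-t", "7.5",
           "-ac", "1", "-ar", "16000", "output.wav"]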

class StreamingASRProcessor:
"""
Handles real-time streaming ASR with buffered inference.
Uses NeMo's FrameBatchASR for efficient streaming with the parakeet-realtime-eou model.
"""

def __init__(self, model, chunk_len_in_secs: float = 0.16, buffer_len_in_secs: float = 1.6):
"""
Initialize streaming processor.

Args:
model: NeMo ASR model
chunk_len_in_secs: Length of each audio chunk (0.16s = 160ms for low latency)
buffer_len_in_secs: Total buffer length including context (1.6s recommended)
"""
self.model = model
self.chunk_len = chunk_len_in_secs
self.buffer_len = buffer_len_in_secs
self.sample_rate = 16000

# Try to import FrameBatchASR
try:
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchASR

# Initialize frame-based ASR for streaming
self.frame_asr = FrameBatchASR(
asr_model=model,
frame_len=chunk_len_in_secs,
total_buffer_in_secs=buffer_len_in_secs,
batch_size=1 # Process one stream at a time
)
self.use_frame_batch = True
logger.info(f"StreamingASRProcessor initialized with FrameBatchASR: chunk={chunk_len_in_secs}s, buffer={buffer_len_in_secs}s")
except ImportError:
# Fallback to basic transcription if FrameBatchASR not available
self.use_frame_batch = False
logger.warning("FrameBatchASR not available, using fallback transcription method")

def reset(self):
"""Reset the streaming buffer for a new session."""
if self.use_frame_batch:
self.frame_asr.reset()

def transcribe_chunk(self, audio_data: np.ndarray) -> Optional[str]:
"""
Transcribe a chunk of audio data.

Args:
audio_data: Float32 numpy array of audio samples (16kHz, mono)

Returns:
Transcription text if available, None otherwise
"""
try:
# Ensure audio is float32 normalized to [-1, 1]
if audio_data.dtype != np.float32:
if audio_data.dtype == np.int16:
audio_data = audio_data.astype(np.float32) / 32768.0
else:
audio_data = audio_data.astype(np.float32)

if self.use_frame_batch:
# Use FrameBatchASR for efficient streaming
hypotheses = self.frame_asr.transcribe_signal(audio_data)
if hypotheses and len(hypotheses) > 0:
text = hypotheses[0]
if text and text.strip():
return text
else:
# Fallback: Create temp WAV file and transcribe
tmp_fd, tmp_path = tempfile.mkstemp(suffix=".wav")
os.close(tmp_fd)

try:
# Convert to int16 for WAV
audio_int16 = (audio_data * 32768.0).astype(np.int16)

with wave.open(tmp_path, 'wb') as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(self.sample_rate)
wav_file.writeframes(audio_int16.tobytes())

# Transcribe
results = self.model.transcribe(
[tmp_path],
batch_size=1,
return_hypotheses=False,
verbose=False,
num_workers=0,
)
texts = extract_texts(results)
text = texts[0] if texts else ""
if text and text.strip():
return text
finally:
if os.path.exists(tmp_path):
os.remove(tmp_path)

return None

except Exception as e:
logger.error(f"Transcription error: {e}", exc_info=True)
return None

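A minimal offline sketch of how StreamingASRProcessor could be exercised outside the server, assuming an already-loaded NeMo model and a 16 kHz mono Int16 WAV file (the file path and the run_offline helper are illustrative, not part of this diff):

    import wave
    import numpy as np

    def run_offline(model, wav_path="sample_16k_mono.wav", chunk_secs=0.16):
        # Feed a pre-recorded file through the processor in 160 ms chunks,
        # roughly mimicking what the WebSocket endpoint does with live audio.
        processor = StreamingASRProcessor(model, chunk_len_in_secs=chunk_secs)
        samples_per_chunk = int(16000 * chunk_secs)
        with wave.open(wav_path, "rb") as wf:
            while True:
                frames = wf.readframes(samples_per_chunk)
                if not frames:
                    break
                chunk = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
                text = processor.transcribe_chunk(chunk)
                if text:
                    print(text)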

@app.post("/v1/transcribe")
async def transcribe(
file: UploadFile = File(...),
@@ -310,3 +419,108 @@ async def transcribe(
pass


@app.websocket("/v1/transcribe/stream")
async def transcribe_stream(websocket: WebSocket):
"""
WebSocket endpoint for real-time streaming transcription with EOU detection.

Protocol:
- Client sends: Binary frames containing raw PCM Int16 audio data (16kHz, mono, little-endian)
- Server responds: JSON messages with transcription results
{"type": "transcription", "text": "hello world", "is_eou": false}
{"type": "transcription", "text": "how are you<EOU>", "is_eou": true}
{"type": "error", "message": "error details"}

The parakeet-realtime-eou model emits <EOU> tokens to signal end-of-utterance,
enabling ultra-low latency (80-160ms) conversational experiences.
"""
await websocket.accept()
logger.info("WebSocket connection established for streaming transcription")

model = model_manager.get_model()

# Initialize streaming processor with low-latency settings
# chunk_len=0.16s gives ~160ms latency (matching model's design)
processor = StreamingASRProcessor(
model=model,
chunk_len_in_secs=0.16, # 160ms chunks for low latency
buffer_len_in_secs=1.6 # 1.6s buffer for context
)

# Audio configuration
sample_rate = 16000
bytes_per_sample = 2 # Int16

# Chunk size: 160ms = 2560 samples = 5120 bytes
samples_per_chunk = int(sample_rate * 0.16)
bytes_per_chunk = samples_per_chunk * bytes_per_sample

# Buffer for accumulating incoming audio
audio_buffer = bytearray()

# Track last transcription to avoid duplicates
last_transcription = ""

try:
while True:
# Receive audio data from client
data = await websocket.receive_bytes()
audio_buffer.extend(data)

# Process when we have enough audio for a chunk
while len(audio_buffer) >= bytes_per_chunk:
# Extract chunk
chunk_bytes = bytes(audio_buffer[:bytes_per_chunk])
audio_buffer = audio_buffer[bytes_per_chunk:]

# Convert bytes to numpy array
audio_int16 = np.frombuffer(chunk_bytes, dtype=np.int16)
audio_float32 = audio_int16.astype(np.float32) / 32768.0

# Transcribe the chunk
transcript = processor.transcribe_chunk(audio_float32)

if transcript and transcript != last_transcription:
# Check for EOU token
has_eou = EOU_TOKEN in transcript

# Send transcription result
response = {
"type": "transcription",
"text": transcript,
"is_eou": has_eou
}
await websocket.send_json(response)
logger.info(f"Transcription: '{transcript}' (EOU: {has_eou})")

last_transcription = transcript

# Reset processor after EOU for next utterance
if has_eou:
processor.reset()
last_transcription = ""

# Small delay to prevent busy loop
await asyncio.sleep(0.01)

except WebSocketDisconnect:
logger.info("WebSocket client disconnected")
except Exception as e:
logger.error(f"WebSocket error: {str(e)}", exc_info=True)
try:
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.send_json({
"type": "error",
"message": str(e)
})
except Exception:
pass
finally:
try:
if websocket.client_state == WebSocketState.CONNECTED:
await websocket.close()
except Exception:
pass
logger.info("WebSocket connection closed")