70 changes: 38 additions & 32 deletions src/server.py
@@ -1,6 +1,6 @@
import json

import numpy as np
import torch

from flask import Flask, send_from_directory, request, jsonify
from flask_cors import CORS, cross_origin
@@ -12,12 +12,16 @@
top_phonetic_errors,
pair_by_words,
)
from transcription import transcribe_timestamped, SAMPLE_RATE
from transcription import (
extract_features_only,
run_transformer_on_features,
STRIDE_SIZE,
)
from phoneme_utils import TIMESTAMPED_PHONES_T, TIMESTAMPED_PHONES_BY_WORD_T

# Constants
DEBUG = False
NUM_SECONDS_PER_CHUNK = 0.5
TRANSFORMER_INTERVAL = 30
Member

We'll definitely want this to be adaptive. Not an immediate priority. Make sure you write good benchmarking code first so you can measure whether each change is an improvement. Then we can also combine this with the VAD and local agreement optimizations.
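Purely as a sketch of what "adaptive" could mean here (the threshold, the step of 5 strides, and the next_interval helper are hypothetical, not part of this PR): time each run_transformer_on_features call and grow or shrink the interval accordingly.

# Illustrative only: adapt how often the full transformer pass runs based on how
# long the previous pass took. TARGET_SECONDS_PER_PASS and the step of 5 strides
# are made-up numbers for this sketch.
TARGET_SECONDS_PER_PASS = 0.3

def next_interval(last_pass_seconds: float, interval: int) -> int:
    if last_pass_seconds > TARGET_SECONDS_PER_PASS and interval < 100:
        return interval + 5  # falling behind real time, re-transcribe less often
    if last_pass_seconds < TARGET_SECONDS_PER_PASS / 2 and interval > 5:
        return interval - 5  # plenty of headroom, update the client more often
    return interval
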

Member

Wrote some suggested utils for benchmarking code here. Make sure to add accuracy, average over multiple samples, and create nice figures for a blog post.
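Not the linked utils, just a minimal illustration of the shape such a benchmark could take (transcribe_fn, the (audio, reference) sample pairs, and the edit-distance error rate are placeholder choices):

import time
import statistics

def edit_distance(a, b):
    """Plain Levenshtein distance between two sequences (e.g. phoneme lists)."""
    prev = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        curr = [i]
        for j, y in enumerate(b, 1):
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, prev[j - 1] + (x != y)))
        prev = curr
    return prev[-1]

def benchmark(transcribe_fn, samples):
    """Average wall-clock latency and error rate over (audio, reference) pairs."""
    latencies, error_rates = [], []
    for audio, reference in samples:
        start = time.perf_counter()
        hypothesis = transcribe_fn(audio)
        latencies.append(time.perf_counter() - start)
        error_rates.append(edit_distance(hypothesis, reference) / max(len(reference), 1))
    return statistics.mean(latencies), statistics.mean(error_rates)
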


# Initialize Flask app
app = Flask(__name__)
@@ -83,44 +87,46 @@ def get_score_words_cer():

@sock.route("/stream")
def stream(ws):
buffer = b"" # Buffer to hold audio chunks
buffer = b""
feature_list = []
total_samples_processed = 0

full_transcription: TIMESTAMPED_PHONES_T = []
accumulated_duration = 0
combined = np.array([], dtype=np.float32)
while True:
try:
# Receive audio data from the client
data = ws.receive()
if data and data != "stop":
buffer += data

# Process when buffer has at least one chunk in it or when we are done
if (
data == "stop"
or len(buffer)
>= SAMPLE_RATE * NUM_SECONDS_PER_CHUNK * np.dtype(np.float32).itemsize
):
audio = np.frombuffer(buffer, dtype=np.float32)
transcription = transcribe_timestamped(audio, accumulated_duration)
accumulated_duration += len(audio) / SAMPLE_RATE
full_transcription.extend(transcription)
ws.send(json.dumps(full_transcription))

if DEBUG:
from scipy.io import wavfile

wavfile.write("src/audio.wav", SAMPLE_RATE, audio)
combined = np.concatenate([combined, audio])
wavfile.write("src/combined.wav", SAMPLE_RATE, combined)

if data == "stop":
break

buffer = b"" # Clear the buffer
# Process 20ms chunks
Member

The 20ms is the stride. The CNN actually has a receptive field of 25ms (400 samples). By only ever feeding it 20ms chunks and padding them to fit the 25ms (MIN_LEN_SAMPLES), you never allow it to read all the audio it needs to perform equivalently to running the full model once at the end. You'll want to let it collect some multiple of 320 samples + 80 samples in each chunk, or, if you prefer to allow the last CNN receptive field to be incomplete, recompute it when more audio comes in.
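A sketch of the first option (collect k * 320 + 80 samples so every CNN receptive field sees real audio); take_complete_chunk is a hypothetical helper, not part of this PR, and per-chunk feature normalization in the processor is ignored here:

import numpy as np
from transcription import RECEPTIVE_FIELD_SIZE, STRIDE_SIZE

# STRIDE_SIZE is 320 samples (20ms) and RECEPTIVE_FIELD_SIZE is 400 samples (25ms)
# for this CNN, so a chunk of k * 320 + 80 samples yields k frames computed from
# real audio instead of padding.
ITEM_SIZE = np.dtype(np.float32).itemsize
OVERLAP = RECEPTIVE_FIELD_SIZE - STRIDE_SIZE  # 80 samples

def take_complete_chunk(buffer: bytes):
    """Return (chunk_bytes, remaining_buffer); chunk_bytes is None if not enough audio yet."""
    available_samples = len(buffer) // ITEM_SIZE
    k = (available_samples - OVERLAP) // STRIDE_SIZE
    if k < 1:
        return None, buffer
    chunk_samples = k * STRIDE_SIZE + OVERLAP
    # Advance the buffer by only k * STRIDE_SIZE samples so the trailing 80 samples
    # are reused as the start of the next chunk's first receptive field.
    return buffer[: chunk_samples * ITEM_SIZE], buffer[k * STRIDE_SIZE * ITEM_SIZE :]
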

while len(buffer) >= STRIDE_SIZE * np.dtype(np.float32).itemsize:
chunk_bytes = buffer[: STRIDE_SIZE * np.dtype(np.float32).itemsize]
buffer = buffer[STRIDE_SIZE * np.dtype(np.float32).itemsize :]

audio_chunk = np.frombuffer(chunk_bytes, dtype=np.float32)

features, samples = extract_features_only(audio_chunk)
feature_list.append(features)
total_samples_processed += samples
# accumulate TRANSFORMER_INTERVAL strides of features (30 x 20ms = 600ms), then send COMPLETE transcription from start
if len(feature_list) % TRANSFORMER_INTERVAL == 0:
all_features = torch.cat(feature_list, dim=1)
full_transcription = run_transformer_on_features(
all_features, total_samples_processed
)
ws.send(json.dumps(full_transcription))

if data == "stop":
# Final update with any remaining features
if feature_list:
all_features = torch.cat(feature_list, dim=1)
full_transcription = run_transformer_on_features(
all_features, total_samples_processed
)
ws.send(json.dumps(full_transcription))
break

except Exception as e:
print(f"Error: {e}")
print(f"Line: {e.__traceback__.tb_lineno if e.__traceback__ else -1}")
break


78 changes: 57 additions & 21 deletions src/transcription.py
@@ -1,53 +1,89 @@
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForCTC
from transformers import (
AutoProcessor,
AutoModelForCTC,
Wav2Vec2Processor,
Wav2Vec2ForCTC,
)
from phoneme_utils import TIMESTAMPED_PHONES_T

SAMPLE_RATE = 16_000

# Load Wav2Vec2 model
model_id = "KoelLabs/xlsr-english-01"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForCTC.from_pretrained(model_id)
processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(model_id)
model: Wav2Vec2ForCTC = AutoModelForCTC.from_pretrained(model_id)
assert processor.feature_extractor.sampling_rate == SAMPLE_RATE


def transcribe_timestamped(audio: np.ndarray, time_offset=0.0) -> TIMESTAMPED_PHONES_T:
input_values = (
processor(
audio,
sampling_rate=processor.feature_extractor.sampling_rate,
return_tensors="pt",
padding=True,
)
.input_values.type(torch.float32)
.to(model.device)
def _calculate_cnn_window(model: Wav2Vec2ForCTC):
receptive_field = 1
stride = 1
for conv_layer in model.wav2vec2.feature_extractor.conv_layers:
assert hasattr(conv_layer, "conv")
conv = conv_layer.conv
assert isinstance(conv, torch.nn.Conv1d)
receptive_field += (conv.kernel_size[0] - 1) * stride
stride *= conv.stride[0]
return receptive_field, stride


RECEPTIVE_FIELD_SIZE, STRIDE_SIZE = _calculate_cnn_window(model)


def extract_features_only(audio: np.ndarray):
"""Extract CNN features and project to encoder hidden size (transformer-ready)."""
# True raw sample count before any padding
raw_sample_count = int(np.asarray(audio).shape[-1])

inputs = processor(
audio,
sampling_rate=SAMPLE_RATE,
return_tensors="pt",
padding="max_length",
max_length=RECEPTIVE_FIELD_SIZE,
)
input_values = inputs.input_values.type(torch.float32).to(model.device)
with torch.no_grad():
logits = model(input_values).logits
conv_feats = model.wav2vec2.feature_extractor(input_values) # (B, C, T')
conv_feats_t = conv_feats.transpose(1, 2) # (B, T', C)
# Project to hidden size for transformer; also returns normalized conv features
features, normed_conv_feats = model.wav2vec2.feature_projection(conv_feats_t)
# Return transformer-ready features and original (unpadded) input length in samples
return features, raw_sample_count

predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
duration_sec = input_values.shape[1] / processor.feature_extractor.sampling_rate

def run_transformer_on_features(
features: torch.Tensor, total_audio_samples: int, time_offset: float = 0.0
) -> TIMESTAMPED_PHONES_T:
"""Run transformer and decode"""
# slowest step
with torch.no_grad():
encoder_outputs = model.wav2vec2.encoder(features)
logits = model.lm_head(encoder_outputs[0])

predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
# Use original audio length in samples to compute duration
duration_sec = total_audio_samples / processor.feature_extractor.sampling_rate
ids_w_time = [
(time_offset + i / len(predicted_ids) * duration_sec, _id)
for i, _id in enumerate(predicted_ids)
]

current_phoneme_id = processor.tokenizer.pad_token_id
current_start_time = 0
phonemes_with_time = []
for time, _id in ids_w_time:
for timestamp, _id in ids_w_time:
if current_phoneme_id != _id:
if current_phoneme_id != processor.tokenizer.pad_token_id:
phonemes_with_time.append(
(
processor.decode(current_phoneme_id),
current_start_time,
time,
timestamp,
)
)
current_start_time = time
current_phoneme_id = _id

current_start_time = timestamp
current_phoneme_id = _id
return phonemes_with_time