3 stream to ipa #13
base: main
Changes from all commits
be0b0d9
a659b8a
aa3592b
cdefb7b
801f74b
a39c125
09a1a69
```diff
@@ -1,6 +1,6 @@
 import json
 
 import numpy as np
+import torch
 
 from flask import Flask, send_from_directory, request, jsonify
 from flask_cors import CORS, cross_origin
@@ -12,12 +12,16 @@
     top_phonetic_errors,
     pair_by_words,
 )
-from transcription import transcribe_timestamped, SAMPLE_RATE
+from transcription import (
+    extract_features_only,
+    run_transformer_on_features,
+    STRIDE_SIZE,
+)
 from phoneme_utils import TIMESTAMPED_PHONES_T, TIMESTAMPED_PHONES_BY_WORD_T
 
 # Constants
 DEBUG = False
-NUM_SECONDS_PER_CHUNK = 0.5
+TRANSFORMER_INTERVAL = 30
 
 # Initialize Flask app
 app = Flask(__name__)
@@ -83,44 +87,46 @@ def get_score_words_cer():
 
 @sock.route("/stream")
 def stream(ws):
-    buffer = b""  # Buffer to hold audio chunks
+    buffer = b""
+    feature_list = []
+    total_samples_processed = 0
 
-    full_transcription: TIMESTAMPED_PHONES_T = []
-    accumulated_duration = 0
-    combined = np.array([], dtype=np.float32)
     while True:
         try:
             # Receive audio data from the client
             data = ws.receive()
             if data and data != "stop":
                 buffer += data
 
-            # Process when buffer has at least one chunk in it or when we are done
-            if (
-                data == "stop"
-                or len(buffer)
-                >= SAMPLE_RATE * NUM_SECONDS_PER_CHUNK * np.dtype(np.float32).itemsize
-            ):
-                audio = np.frombuffer(buffer, dtype=np.float32)
-                transcription = transcribe_timestamped(audio, accumulated_duration)
-                accumulated_duration += len(audio) / SAMPLE_RATE
-                full_transcription.extend(transcription)
-                ws.send(json.dumps(full_transcription))
-
-                if DEBUG:
-                    from scipy.io import wavfile
-
-                    wavfile.write("src/audio.wav", SAMPLE_RATE, audio)
-                    combined = np.concatenate([combined, audio])
-                    wavfile.write("src/combined.wav", SAMPLE_RATE, combined)
-
-                if data == "stop":
-                    break
-
-                buffer = b""  # Clear the buffer
+            # Process 20ms chunks
```
> **Member:** The 20ms is the stride. The CNN actually has a receptive field of 25ms (400 samples). By only ever feeding it 20ms chunks and padding them to fit the 25ms (
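As an aside, the arithmetic in this comment can be checked without loading the model. Below is a minimal standalone sketch, assuming wav2vec2's default CNN configuration (kernels 10,3,3,3,3,2,2 and strides 5,2,2,2,2,2,2); it mirrors what `_calculate_cnn_window` in the second file computes from the live model.

```python
# Standalone check of the numbers in the comment above, assuming wav2vec2's
# default CNN stack: kernels (10, 3, 3, 3, 3, 2, 2), strides (5, 2, 2, 2, 2, 2, 2).
KERNELS = (10, 3, 3, 3, 3, 2, 2)
STRIDES = (5, 2, 2, 2, 2, 2, 2)

receptive_field, stride = 1, 1
for kernel, layer_stride in zip(KERNELS, STRIDES):
    receptive_field += (kernel - 1) * stride  # each layer widens the input window
    stride *= layer_stride                    # and multiplies the hop between frames

assert receptive_field == 400  # 400 samples / 16 kHz = 25 ms receptive field
assert stride == 320           # 320 samples / 16 kHz = 20 ms stride
```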
```diff
+            while len(buffer) >= STRIDE_SIZE * np.dtype(np.float32).itemsize:
+                chunk_bytes = buffer[: STRIDE_SIZE * np.dtype(np.float32).itemsize]
+                buffer = buffer[STRIDE_SIZE * np.dtype(np.float32).itemsize :]
+
+                audio_chunk = np.frombuffer(chunk_bytes, dtype=np.float32)
+
+                features, samples = extract_features_only(audio_chunk)
+                feature_list.append(features)
+                total_samples_processed += samples
+                # accumulate features for 500ms (25 sets of 20ms), then send COMPLETE transcription from start
+                if len(feature_list) % TRANSFORMER_INTERVAL == 0:
+                    all_features = torch.cat(feature_list, dim=1)
+                    full_transcription = run_transformer_on_features(
+                        all_features, total_samples_processed
+                    )
+                    ws.send(json.dumps(full_transcription))
+
+            if data == "stop":
+                # Final update with any remaining features
+                if feature_list:
+                    all_features = torch.cat(feature_list, dim=1)
+                    full_transcription = run_transformer_on_features(
+                        all_features, total_samples_processed
+                    )
+                    ws.send(json.dumps(full_transcription))
+                break
 
         except Exception as e:
             print(f"Error: {e}")
             print(f"Line: {e.__traceback__.tb_lineno if e.__traceback__ else -1}")
             break
```
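For context on the protocol the `/stream` handler above expects (raw float32 PCM frames followed by the literal string `"stop"`), here is a hypothetical client sketch. The host, port, and the `websocket-client` dependency are assumptions, not part of this PR.

```python
# Hypothetical client for the /stream socket. Assumes the Flask app is
# reachable at localhost:5000 and that websocket-client is installed.
import json

import numpy as np
from websocket import WebSocketConnectionClosedException, create_connection

ws = create_connection("ws://localhost:5000/stream")
audio = np.zeros(16_000, dtype=np.float32)  # one second of silence at 16 kHz

# Stream raw float32 PCM in 20 ms (320-sample) chunks, matching STRIDE_SIZE.
for start in range(0, len(audio), 320):
    ws.send_binary(audio[start : start + 320].tobytes())

ws.send("stop")  # ask the server to flush the final transcription

# Intermediate updates may arrive first; the last message holds the complete
# transcription from the start of the stream.
try:
    while True:
        message = ws.recv()
        if not message:
            break
        print(json.loads(message))
except WebSocketConnectionClosedException:
    pass  # server closed the socket after the final update
finally:
    ws.close()
```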
```diff
@@ -1,53 +1,89 @@
 import torch
 import numpy as np
-from transformers import AutoProcessor, AutoModelForCTC
+from transformers import (
+    AutoProcessor,
+    AutoModelForCTC,
+    Wav2Vec2Processor,
+    Wav2Vec2ForCTC,
+)
 from phoneme_utils import TIMESTAMPED_PHONES_T
 
 SAMPLE_RATE = 16_000
 
 # Load Wav2Vec2 model
 model_id = "KoelLabs/xlsr-english-01"
-processor = AutoProcessor.from_pretrained(model_id)
-model = AutoModelForCTC.from_pretrained(model_id)
+processor: Wav2Vec2Processor = AutoProcessor.from_pretrained(model_id)
+model: Wav2Vec2ForCTC = AutoModelForCTC.from_pretrained(model_id)
 assert processor.feature_extractor.sampling_rate == SAMPLE_RATE
 
 
-def transcribe_timestamped(audio: np.ndarray, time_offset=0.0) -> TIMESTAMPED_PHONES_T:
-    input_values = (
-        processor(
-            audio,
-            sampling_rate=processor.feature_extractor.sampling_rate,
-            return_tensors="pt",
-            padding=True,
-        )
-        .input_values.type(torch.float32)
-        .to(model.device)
+def _calculate_cnn_window(model: Wav2Vec2ForCTC):
+    receptive_field = 1
+    stride = 1
+    for conv_layer in model.wav2vec2.feature_extractor.conv_layers:
+        assert hasattr(conv_layer, "conv")
+        conv = conv_layer.conv
+        assert isinstance(conv, torch.nn.Conv1d)
+        receptive_field += (conv.kernel_size[0] - 1) * stride
+        stride *= conv.stride[0]
+    return receptive_field, stride
+
+
+RECEPTIVE_FIELD_SIZE, STRIDE_SIZE = _calculate_cnn_window(model)
+
+
+def extract_features_only(audio: np.ndarray):
+    """Extract CNN features and project to encoder hidden size (transformer-ready)."""
+    # True raw sample count before any padding
+    raw_sample_count = int(np.asarray(audio).shape[-1])
+
+    inputs = processor(
+        audio,
+        sampling_rate=SAMPLE_RATE,
+        return_tensors="pt",
+        padding="max_length",
+        max_length=RECEPTIVE_FIELD_SIZE,
     )
+    input_values = inputs.input_values.type(torch.float32).to(model.device)
     with torch.no_grad():
-        logits = model(input_values).logits
+        conv_feats = model.wav2vec2.feature_extractor(input_values)  # (B, C, T')
+        conv_feats_t = conv_feats.transpose(1, 2)  # (B, T', C)
+        # Project to hidden size for transformer; also returns normalized conv features
+        features, normed_conv_feats = model.wav2vec2.feature_projection(conv_feats_t)
+    # Return transformer-ready features and original (unpadded) input length in samples
+    return features, raw_sample_count
 
-    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
-    duration_sec = input_values.shape[1] / processor.feature_extractor.sampling_rate
 
+def run_transformer_on_features(
+    features: torch.Tensor, total_audio_samples: int, time_offset: float = 0.0
+) -> TIMESTAMPED_PHONES_T:
+    """Run transformer and decode"""
+    # slowest step
+    with torch.no_grad():
+        encoder_outputs = model.wav2vec2.encoder(features)
+
+    logits = model.lm_head(encoder_outputs[0])
+
+    predicted_ids = torch.argmax(logits, dim=-1)[0].tolist()
+    # Use original audio length in samples to compute duration
+    duration_sec = total_audio_samples / processor.feature_extractor.sampling_rate
     ids_w_time = [
         (time_offset + i / len(predicted_ids) * duration_sec, _id)
         for i, _id in enumerate(predicted_ids)
     ]
 
     current_phoneme_id = processor.tokenizer.pad_token_id
     current_start_time = 0
     phonemes_with_time = []
-    for time, _id in ids_w_time:
+    for timestamp, _id in ids_w_time:
         if current_phoneme_id != _id:
             if current_phoneme_id != processor.tokenizer.pad_token_id:
                 phonemes_with_time.append(
                     (
                         processor.decode(current_phoneme_id),
                         current_start_time,
-                        time,
+                        timestamp,
                     )
                 )
-            current_start_time = time
+            current_start_time = timestamp
             current_phoneme_id = _id
     return phonemes_with_time
```
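As a quick usage sketch of the two new helpers, here is a smoke test under the assumption that the model weights download successfully; the input is silence, so the expected output is an empty list.

```python
# Smoke test for the new streaming API: extract features for two 20 ms
# chunks, concatenate along the time axis, and decode once from the start.
import numpy as np
import torch

from transcription import (
    STRIDE_SIZE,
    extract_features_only,
    run_transformer_on_features,
)

chunks = [np.zeros(STRIDE_SIZE, dtype=np.float32) for _ in range(2)]

feature_list, total_samples = [], 0
for chunk in chunks:
    features, samples = extract_features_only(chunk)  # (1, T', hidden), 320
    feature_list.append(features)
    total_samples += samples

all_features = torch.cat(feature_list, dim=1)  # concatenate time frames
print(run_transformer_on_features(all_features, total_samples))  # likely []
```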
> **Review comment:** We'll definitely want this to be adaptive. Not an immediate priority. Make sure you write good benchmarking code first so you can measure whether each change is an improvement. Then we can also combine with the VAD and local agreement optimizations.
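One possible starting point for that benchmarking code, purely illustrative (the harness and the candidate intervals are assumptions, not part of the PR): time one full streaming pass per configuration so each change, including an adaptive interval later, can be compared on the same audio.

```python
# Illustrative timing harness: measure one streaming pass end to end for a
# few candidate re-decode intervals.
import time

import numpy as np
import torch

from transcription import (
    STRIDE_SIZE,
    extract_features_only,
    run_transformer_on_features,
)


def time_streaming_pass(audio: np.ndarray, transformer_interval: int) -> float:
    """Wall-clock seconds to stream `audio` with a given re-decode interval."""
    feature_list, total_samples = [], 0
    start = time.perf_counter()
    for i in range(0, len(audio) - STRIDE_SIZE + 1, STRIDE_SIZE):
        features, samples = extract_features_only(audio[i : i + STRIDE_SIZE])
        feature_list.append(features)
        total_samples += samples
        if len(feature_list) % transformer_interval == 0:
            all_features = torch.cat(feature_list, dim=1)
            run_transformer_on_features(all_features, total_samples)
    return time.perf_counter() - start


audio = np.random.randn(5 * 16_000).astype(np.float32)  # 5 s of noise
for interval in (10, 30, 50):  # candidate TRANSFORMER_INTERVAL values
    print(interval, round(time_streaming_pass(audio, interval), 3), "s")
```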
> **Review comment:** Wrote some suggested utils for benchmarking code here. Make sure to add accuracy, average over multiple samples, and create nice figures for a blog post.
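A sketch of the accuracy half might look like the following, under the assumption that a small evaluation set of (hypothesis, reference) phone sequences is available; it averages phone error rate over samples and saves a per-sample figure. All names here are placeholders.

```python
# Placeholder accuracy benchmark: mean phone error rate over several samples
# plus a simple per-sample figure. The `samples` list stands in for a real
# evaluation set.
import matplotlib.pyplot as plt


def edit_distance(hyp: list, ref: list) -> int:
    """Levenshtein distance between two phone sequences."""
    prev = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, 1):
        cur = [i]
        for j, r in enumerate(ref, 1):
            cur.append(min(prev[j] + 1, cur[-1] + 1, prev[j - 1] + (h != r)))
        prev = cur
    return prev[-1]


def phone_error_rate(hyp: list, ref: list) -> float:
    return edit_distance(hyp, ref) / max(len(ref), 1)


samples = [  # placeholder (hypothesis, reference) pairs
    (["h", "ɛ", "l", "oʊ"], ["h", "ə", "l", "oʊ"]),
    (["w", "ɚ", "l", "d"], ["w", "ɜ", "l", "d"]),
]
rates = [phone_error_rate(h, r) for h, r in samples]
print("mean PER:", sum(rates) / len(rates))

plt.bar(range(len(rates)), rates)
plt.xlabel("sample")
plt.ylabel("phone error rate")
plt.savefig("per_per_sample.png")
```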