From d3bab3aee6f78b3f7dea178ba158023e4aae311f Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Thu, 8 Jan 2026 09:51:33 -0500
Subject: [PATCH 1/9] Initial parakeet timestamps impl

---
 examples/models/parakeet/README.md           |   1 +
 .../models/parakeet/export_parakeet_tdt.py   |   4 +
 examples/models/parakeet/main.cpp            | 243 +++++++++++++++++-
 3 files changed, 244 insertions(+), 4 deletions(-)

diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md
index b27bc1f8a91..68eab0f27bf 100644
--- a/examples/models/parakeet/README.md
+++ b/examples/models/parakeet/README.md
@@ -71,3 +71,4 @@ From the executorch root directory:
 | `--audio_path` | Path to input audio file (.wav) |
 | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) |
 | `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) |
+| `--timestamps` | Print word/segment timestamps (requires `window_stride` + `encoder_subsampling_factor` in model metadata) |

diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py
index 92e32ca30bf..db911279252 100644
--- a/examples/models/parakeet/export_parakeet_tdt.py
+++ b/examples/models/parakeet/export_parakeet_tdt.py
@@ -351,6 +351,8 @@ def export_all(model):
     )
 
     sample_rate = model.preprocessor._cfg.sample_rate
+    window_stride = float(model.preprocessor._cfg.window_stride)
+    encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 1))
     metadata = {
         "num_rnn_layers": num_layers,
         "pred_hidden": pred_hidden,
@@ -358,6 +360,8 @@ def export_all(model):
         "vocab_size": model.tokenizer.vocab_size,
         "blank_id": model.tokenizer.vocab_size,
         "sample_rate": sample_rate,
+        "window_stride": window_stride,
+        "encoder_subsampling_factor": encoder_subsampling_factor,
     }
 
     return programs, metadata

diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index 026f3911a3d..b9c1f81d920 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -6,9 +6,11 @@
  * LICENSE file in the root directory of this source tree.
 */
 
+#include <cctype>
 #include <cmath>
 #include <cstring>
 #include <fstream>
+#include <iomanip>
 #include <iostream>
 #include <memory>
 #include <string>
@@ -33,6 +35,10 @@ DEFINE_string(
     data_path,
     "",
     "Path to data file (.ptd) for delegate data (optional, required for CUDA).");
+DEFINE_bool(
+    timestamps,
+    false,
+    "Output word- and segment-level timestamps (requires model metadata).");
 
 using ::executorch::extension::from_blob;
 using ::executorch::extension::Module;
@@ -44,7 +50,179 @@ namespace {
 // TDT duration values
 const std::vector<int64_t> DURATIONS = {0, 1, 2, 3, 4};
 
-std::vector<int64_t> greedy_decode_executorch(
+struct TokenTimestamp {
+  int64_t id;
+  int64_t start_offset; // encoder frame index
+  int64_t end_offset; // encoder frame index
+};
+
+struct WordTimestamp {
+  std::string text;
+  int64_t start_offset;
+  int64_t end_offset;
+  double start_sec;
+  double end_sec;
+};
+
+struct SegmentTimestamp {
+  std::string text;
+  int64_t start_offset;
+  int64_t end_offset;
+  double start_sec;
+  double end_sec;
+};
+
+bool is_ascii_punctuation_only(const std::string& s) {
+  if (s.empty()) {
+    return false;
+  }
+  for (unsigned char ch : s) {
+    if (!std::ispunct(ch)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+size_t ltrim_ascii_whitespace(const std::string& s) {
+  size_t i = 0;
+  while (i < s.size() && std::isspace(static_cast<unsigned char>(s[i]))) {
+    i++;
+  }
+  return i;
+}
+
+std::vector<WordTimestamp> tokens_to_word_timestamps(
+    const std::vector<TokenTimestamp>& tokens,
+    tokenizers::Tokenizer* tokenizer,
+    double seconds_per_encoder_frame) {
+  std::vector<WordTimestamp> words;
+  if (!tokenizer || tokens.empty()) {
+    return words;
+  }
+
+  uint64_t prev_token = 0;
+
+  std::string current_word;
+  int64_t current_start_offset = 0;
+  int64_t current_end_offset = 0;
+  int64_t prev_end_offset = 0;
+  bool has_prev_end_offset = false;
+
+  auto emit_word = [&]() {
+    if (current_word.empty()) {
+      return;
+    }
+    double start_sec = -1.0;
+    double end_sec = -1.0;
+    if (seconds_per_encoder_frame > 0.0) {
+      start_sec = seconds_per_encoder_frame * current_start_offset;
+      end_sec = seconds_per_encoder_frame * current_end_offset;
+    }
+    words.push_back(WordTimestamp{
+        current_word, current_start_offset, current_end_offset, start_sec, end_sec});
+    current_word.clear();
+  };
+
+  for (const auto& token_ts : tokens) {
+    uint64_t token = static_cast<uint64_t>(token_ts.id);
+    auto decode_result = tokenizer->decode(prev_token, token);
+    prev_token = token;
+
+    std::string piece = decode_result.ok() ? decode_result.get() : std::string();
+    size_t non_ws = ltrim_ascii_whitespace(piece);
+    bool had_leading_ws = non_ws > 0;
+    std::string trimmed_piece = piece.substr(non_ws);
+
+    if (trimmed_piece.empty()) {
+      continue;
+    }
+
+    TokenTimestamp adjusted = token_ts;
+    const bool is_punct = is_ascii_punctuation_only(trimmed_piece);
+    if (is_punct && has_prev_end_offset) {
+      // TDT can sometimes emit punctuation long after the preceding word. Pin
+      // punctuation timing to the previous token end.
+      adjusted.start_offset = prev_end_offset;
+      adjusted.end_offset = prev_end_offset;
+    }
+
+    if (current_word.empty()) {
+      current_word = trimmed_piece;
+      current_start_offset = adjusted.start_offset;
+      current_end_offset = adjusted.end_offset;
+    } else if (had_leading_ws && !is_punct) {
+      emit_word();
+      current_word = trimmed_piece;
+      current_start_offset = adjusted.start_offset;
+      current_end_offset = adjusted.end_offset;
+    } else {
+      current_word += trimmed_piece;
+      current_end_offset = adjusted.end_offset;
+    }
+
+    prev_end_offset = adjusted.end_offset;
+    has_prev_end_offset = true;
+  }
+
+  emit_word();
+  return words;
+}
+
+std::vector<SegmentTimestamp> words_to_segment_timestamps(
+    const std::vector<WordTimestamp>& words,
+    double seconds_per_encoder_frame) {
+  std::vector<SegmentTimestamp> segments;
+  if (words.empty()) {
+    return segments;
+  }
+
+  std::string current_segment;
+  int64_t segment_start_offset = 0;
+  int64_t segment_end_offset = 0;
+  bool has_segment = false;
+
+  auto emit_segment = [&]() {
+    if (!has_segment || current_segment.empty()) {
+      return;
+    }
+    double start_sec = -1.0;
+    double end_sec = -1.0;
+    if (seconds_per_encoder_frame > 0.0) {
+      start_sec = seconds_per_encoder_frame * segment_start_offset;
+      end_sec = seconds_per_encoder_frame * segment_end_offset;
+    }
+    segments.push_back(SegmentTimestamp{
+        current_segment, segment_start_offset, segment_end_offset, start_sec, end_sec});
+    current_segment.clear();
+    has_segment = false;
+  };
+
+  for (const auto& word : words) {
+    if (!has_segment) {
+      has_segment = true;
+      current_segment = word.text;
+      segment_start_offset = word.start_offset;
+      segment_end_offset = word.end_offset;
+    } else {
+      current_segment += " ";
+      current_segment += word.text;
+      segment_end_offset = word.end_offset;
+    }
+
+    if (!word.text.empty()) {
+      char last = word.text.back();
+      if (last == '.' || last == '!' || last == '?') {
+        emit_segment();
+      }
+    }
+  }
+
+  emit_segment();
+  return segments;
+}
+
+std::vector<TokenTimestamp> greedy_decode_executorch(
     Module& model,
     const ::executorch::aten::Tensor& encoder_output,
     int64_t encoder_len,
@@ -53,7 +231,7 @@ std::vector<int64_t> greedy_decode_executorch(
     int64_t num_rnn_layers = 2,
     int64_t pred_hidden = 640,
     int64_t max_symbols_per_step = 10) {
-  std::vector<int64_t> hypothesis;
+  std::vector<TokenTimestamp> hypothesis;
   int64_t num_token_classes = vocab_size + 1;
 
   // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim]
@@ -208,7 +386,7 @@ std::vector<int64_t> greedy_decode_executorch(
       t += std::max(dur, (int64_t)1);
       symbols_on_frame = 0;
     } else {
-      hypothesis.push_back(k);
+      hypothesis.push_back(TokenTimestamp{k, t, t + dur});
 
       // Update decoder state
       std::vector<int64_t> token_data = {k};
@@ -431,9 +609,66 @@ int main(int argc, char** argv) {
   }
 
   // Convert tokens to text
-  std::string text = tokens_to_text(tokens, tokenizer.get());
+  std::vector<int64_t> token_ids;
+  token_ids.reserve(tokens.size());
+  for (const auto& t : tokens) {
+    token_ids.push_back(t.id);
+  }
+  std::string text = tokens_to_text(token_ids, tokenizer.get());
   std::cout << "Transcription tokens: " << text << std::endl;
 
+  if (FLAGS_timestamps) {
+    std::vector<::executorch::runtime::EValue> empty_inputs;
+    auto window_stride_result = model->execute("window_stride", empty_inputs);
+    auto subsampling_factor_result =
+        model->execute("encoder_subsampling_factor", empty_inputs);
+
+    double seconds_per_encoder_frame = -1.0;
+    if (window_stride_result.ok() && subsampling_factor_result.ok()) {
+      double window_stride = window_stride_result.get()[0].toDouble();
+      int64_t encoder_subsampling_factor =
+          subsampling_factor_result.get()[0].toInt();
+      seconds_per_encoder_frame = window_stride * encoder_subsampling_factor;
+      ET_LOG(
+          Info,
+          "Timestamp metadata: window_stride=%f, encoder_subsampling_factor=%lld, seconds_per_encoder_frame=%f",
+          window_stride,
+          static_cast<long long>(encoder_subsampling_factor),
+          seconds_per_encoder_frame);
+    } else {
+      ET_LOG(
+          Error,
+          "Timestamps requested but model metadata is missing. Re-export the model with constant_methods for window_stride and encoder_subsampling_factor.");
+    }
+
+    auto words = tokens_to_word_timestamps(
+        tokens, tokenizer.get(), seconds_per_encoder_frame);
+    auto segments =
+        words_to_segment_timestamps(words, seconds_per_encoder_frame);
+
+    std::cout << std::fixed << std::setprecision(2);
+
+    std::cout << "\nWord timestamps:\n";
+    for (const auto& w : words) {
+      if (seconds_per_encoder_frame > 0.0) {
+        std::cout << "[" << w.start_sec << ", " << w.end_sec << "] ";
+      } else {
+        std::cout << "[" << w.start_offset << ", " << w.end_offset << "] ";
+      }
+      std::cout << w.text << "\n";
+    }
+
+    std::cout << "\nSegment timestamps:\n";
+    for (const auto& s : segments) {
+      if (seconds_per_encoder_frame > 0.0) {
+        std::cout << "[" << s.start_sec << ", " << s.end_sec << "] ";
+      } else {
+        std::cout << "[" << s.start_offset << ", " << s.end_offset << "] ";
+      }
+      std::cout << s.text << "\n";
+    }
+  }
+
   ET_LOG(Info, "Done!");
   return 0;
 }

From 29c30a84364a28896861ab5d4db8417006682340 Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Thu, 8 Jan 2026 13:39:56 -0500
Subject: [PATCH 2/9] dev annotation

---
 examples/models/parakeet/main.cpp | 111 +++++++++++++++++++++++++++++-
 1 file changed, 109 insertions(+), 2 deletions(-)

diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index b9c1f81d920..fe2c0fb62fb 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -49,12 +49,32 @@ namespace {
 
 // TDT duration values
 const std::vector<int64_t> DURATIONS = {0, 1, 2, 3, 4};
+// NeMo: TDT maps a duration-class argmax -> "skip" (advance) in encoder frames.
+// - Viable duration choices come from loss/config into decoding_cfg.durations:
+//   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
+// - Greedy TDT decoding uses: skip = self.durations[d_k]
+//   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2665-L2669
+// Divergence: we hardcode {0,1,2,3,4} here (matches current Parakeet TDT
+// export). If the exported model's duration set changes, this must be updated
+// to match or timestamps/decoding will drift.
 
 struct TokenTimestamp {
   int64_t id;
   int64_t start_offset; // encoder frame index
   int64_t end_offset; // encoder frame index
 };
+// NeMo: TDT timing is represented on the Hypothesis as two parallel lists:
+// - `Hypothesis.timestamp` holds the encoder frame index where each non-blank
+//   token was emitted.
+// - `Hypothesis.token_duration` holds the predicted duration/skip for that token.
+// These are later converted to `{start_offset, end_offset}` via
+// RNNTDecoding._compute_offsets_tdt().
+// - Where NeMo records `timestamp` + `token_duration` during greedy TDT decode:
+//   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2604-L2693
+// - Where NeMo converts them into offsets:
+//   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1128-L1156
+// Divergence: this ExecuTorch example stores `{start_offset,end_offset}` directly
+// as it decodes. We don't preserve NeMo's intermediate `timestep` list.
 
 struct WordTimestamp {
   std::string text;
@@ -63,6 +83,13 @@ struct WordTimestamp {
   double start_sec;
   double end_sec;
 };
+// NeMo: word-level timestamps are built from per-token offsets with
+// `get_words_offsets()` and then converted from offsets->seconds by
+// `process_timestamp_outputs()`.
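A note on the offset -> seconds conversion this commit wires up: each encoder frame spans window_stride * encoder_subsampling_factor seconds. A minimal standalone sketch of that arithmetic, assuming the 10 ms window stride and 8x subsampling commonly used for Parakeet-TDT (illustrative values only; the runner reads the real ones from the exported metadata):

    #include <cstdint>
    #include <iostream>

    int main() {
      // Assumed values for illustration; not read from a real model here.
      const double window_stride = 0.01;   // 10 ms feature hop
      const int64_t subsampling = 8;       // encoder downsampling factor
      const double sec_per_frame = window_stride * subsampling; // 0.08 s

      // A token emitted at encoder frame 25 with duration 3 frames:
      const int64_t start_offset = 25, end_offset = 28;
      std::cout << start_offset * sec_per_frame << " s to "
                << end_offset * sec_per_frame << " s\n"; // 2 s to 2.24 s
      return 0;
    }

With these assumed defaults a frame is 80 ms, which is the granularity at which the example can report timestamps.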
+// - Word grouping: https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
+// - Offsets->seconds: https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L428-L479
+// Divergence: we only implement word + segment timestamps in this example (no
+// char/subword timestamps).
 
 struct SegmentTimestamp {
   std::string text;
@@ -71,8 +98,20 @@ struct SegmentTimestamp {
   double start_sec;
   double end_sec;
 };
+// NeMo: segment-level timestamps are built from word offsets with
+// `get_segment_offsets()`.
+// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
+// Divergence: we only segment on terminal punctuation (.,!,?) and do not
+// implement NeMo's optional `segment_gap_threshold` behavior.
 
 bool is_ascii_punctuation_only(const std::string& s) {
+  // NeMo: TDT punctuation timestamp refinement is applied when a punctuation
+  // token appears long after the previous token; NeMo "pins" punctuation to the
+  // previous token's end offset.
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
+  // Divergence: NeMo checks membership in a model-specific `supported_punctuation`
+  // set (can include non-ASCII). Here we approximate by checking ASCII
+  // `std::ispunct()` on bytes.
   if (s.empty()) {
     return false;
   }
@@ -85,6 +124,11 @@ bool is_ascii_punctuation_only(const std::string& s) {
 }
 
 size_t ltrim_ascii_whitespace(const std::string& s) {
+  // NeMo: word boundaries for BPE/WPE are detected via tokenizer-type-specific
+  // logic in `get_words_offsets()` (word delimiter char, special markers, etc).
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
+  // Divergence: we treat leading *ASCII whitespace* in the decoded piece as the
+  // only word boundary signal.
   size_t i = 0;
   while (i < s.size() && std::isspace(static_cast<unsigned char>(s[i]))) {
     i++;
@@ -96,6 +140,15 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
     const std::vector<TokenTimestamp>& tokens,
     tokenizers::Tokenizer* tokenizer,
     double seconds_per_encoder_frame) {
+  // NeMo reference for word grouping (subword/char offsets -> word offsets):
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
+  //
+  // Divergences from NeMo:
+  // - NeMo builds words from decoded token offsets (and handles tokenizer types);
+  //   here we build words by incrementally decoding each token and using leading
+  //   ASCII whitespace as the boundary.
+  // - NeMo returns `Hypothesis.timestamp['char']` in addition to word/segment;
+  //   this example does not generate char/subword timestamps.
   std::vector<WordTimestamp> words;
   if (!tokenizer || tokens.empty()) {
     return words;
@@ -143,6 +196,10 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
     if (is_punct && has_prev_end_offset) {
       // TDT can sometimes emit punctuation long after the preceding word. Pin
       // punctuation timing to the previous token end.
+      // NeMo: RNNTDecoding._refine_timestamps_tdt() applies the same correction:
+      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
+      // Divergence: NeMo consults `supported_punctuation` from the model; here we
+      // approximate punctuation detection (ASCII-only) via `is_ascii_punctuation_only()`.
       adjusted.start_offset = prev_end_offset;
       adjusted.end_offset = prev_end_offset;
     }
@@ -152,6 +209,11 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
       current_start_offset = adjusted.start_offset;
       current_end_offset = adjusted.end_offset;
     } else if (had_leading_ws && !is_punct) {
+      // NeMo: `get_words_offsets()` decides when a new word starts using
+      // tokenizer-aware rules (delimiter markers, WPE "##" prefixes, etc):
+      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
+      // Divergence: our boundary rule is strictly "decoded piece had leading
+      // ASCII whitespace and is not punctuation".
       emit_word();
       current_word = trimmed_piece;
       current_start_offset = adjusted.start_offset;
@@ -172,6 +234,10 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
     const std::vector<WordTimestamp>& words,
     double seconds_per_encoder_frame) {
+  // NeMo reference for segment grouping (word offsets -> segment offsets):
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
+  // Divergence: we only segment on terminal punctuation (.,!,?) and do not
+  // implement NeMo's optional `segment_gap_threshold` splitting.
   std::vector<SegmentTimestamp> segments;
   if (words.empty()) {
     return segments;
@@ -213,6 +279,8 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
     if (!word.text.empty()) {
       char last = word.text.back();
       if (last == '.' || last == '!' || last == '?') {
+        // NeMo: segment delimiters are configurable (default includes '.', '?', '!'):
+        // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L287-L296
         emit_segment();
       }
     }
@@ -231,6 +299,19 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
     int64_t num_rnn_layers = 2,
     int64_t pred_hidden = 640,
     int64_t max_symbols_per_step = 10) {
+  // NeMo reference for greedy TDT decoding (where the *token timing* originates):
+  // - Core greedy TDT loop (token argmax + duration argmax + skip advance):
+  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2627-L2717
+  // - NeMo records per-token `timestamp` + `token_duration` here:
+  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2684-L2693
+  //
+  // Divergences from NeMo:
+  // - We implement a single-item (B=1) decoder loop directly over ExecuTorch
+  //   exported methods (`joint_project_*`, `decoder_predict`, `joint`).
+  // - We take argmax directly on raw logits (no log_softmax); this matches
+  //   NeMo's argmax choice but we do not compute scores/confidence.
+  // - NeMo's loop structure uses an explicit inner loop for `skip==0` label
+  //   looping; here we emulate it with `dur==0` and `symbols_on_frame`.
   std::vector<TokenTimestamp> hypothesis;
   int64_t num_token_classes = vocab_size + 1;
 
@@ -286,9 +367,9 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
   // Prime the prediction network state with SOS (= blank_id) to match NeMo TDT
   // greedy label-looping decoding behavior:
   // - SOS is defined as blank:
-  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c980b70cecc184fa8a083a9c3ddb87f905e/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py#L250
+  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py#L250
   // - Predictor priming with SOS:
-  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c980b70cecc184fa8a083a9c3ddb87f905e/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py#L363-L368
+  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py#L363-L368
   std::vector<int64_t> sos_token_data = {blank_id};
   auto sos_token = from_blob(
       sos_token_data.data(), {1, 1}, ::executorch::aten::ScalarType::Long);
@@ -383,9 +464,20 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
     int64_t dur = DURATIONS[dur_idx];
 
     if (k == blank_id) {
+      // NeMo: if blank is emitted with duration=0, it forces progress to avoid
+      // infinite loops (skip==0 -> skip=1):
+      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2700-L2704
+      // Divergence: NeMo advances `time_idx += skip` first and patches `skip`
+      // after the inner loop; here we apply `max(dur,1)` immediately in the
+      // blank branch.
       t += std::max(dur, (int64_t)1);
       symbols_on_frame = 0;
     } else {
+      // NeMo: emits token at `time_idx` and stores duration separately:
+      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2684-L2693
+      // NeMo later converts (timestamp, token_duration) -> (start_offset, end_offset):
+      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1128-L1156
+      // Divergence: we store (start_offset=t, end_offset=t+dur) directly.
       hypothesis.push_back(TokenTimestamp{k, t, t + dur});
 
       // Update decoder state
@@ -432,6 +524,12 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
     t += dur;
 
     if (dur == 0) {
+      // NeMo: label looping occurs when `skip == 0` (stay on same encoder frame)
+      // until a non-zero skip is predicted, capped by `max_symbols_per_step`:
+      // - need_loop = (skip == 0):
+      //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2695-L2699
+      // - force progress after max symbols:
+      //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2715-L2716
       symbols_on_frame++;
       if (symbols_on_frame >= max_symbols_per_step) {
         t++;
@@ -618,6 +716,15 @@ int main(int argc, char** argv) {
   std::cout << "Transcription tokens: " << text << std::endl;
 
   if (FLAGS_timestamps) {
+    // NeMo: offset->seconds conversion uses
+    //   start = start_offset * window_stride * subsampling_factor
+    //   end = end_offset * window_stride * subsampling_factor
+    // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L428-L479
+    //
+    // Divergence: NeMo reads `window_stride` from the preprocessor config and
+    // `subsampling_factor` from the encoder module. In ExecuTorch we require
+    // these values to be exported as `constant_methods` (`window_stride` and
+    // `encoder_subsampling_factor`). If unavailable, we print raw offsets.
     std::vector<::executorch::runtime::EValue> empty_inputs;
     auto window_stride_result = model->execute("window_stride", empty_inputs);
     auto subsampling_factor_result =
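To make the annotated control flow above concrete, here is a self-contained toy of the greedy TDT loop's frame-advance rules (a blank must advance at least one frame; duration 0 label-loops on the same frame, capped by max_symbols_per_step). The scripted (token, duration-class) pairs are invented stand-ins for joint-network argmax results; no model is involved:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    struct JointOut {
      int64_t token;
      int64_t dur_idx;
    };

    int main() {
      const std::vector<int64_t> durations = {0, 1, 2, 3, 4};
      const int64_t blank_id = 4;
      const int64_t num_frames = 6;
      const int64_t max_symbols_per_step = 10;

      // Scripted stand-ins for joint-network argmax results.
      const std::vector<JointOut> script = {{1, 0}, {2, 2}, {blank_id, 1}, {3, 4}};

      size_t call = 0;
      int64_t t = 0;
      int64_t symbols_on_frame = 0;
      while (t < num_frames && call < script.size()) {
        const JointOut out = script[call++];
        const int64_t dur = durations[out.dur_idx];
        if (out.token == blank_id) {
          // A blank must advance at least one frame to guarantee progress.
          t += std::max<int64_t>(dur, 1);
          symbols_on_frame = 0;
        } else {
          std::cout << "token " << out.token << " spans frames [" << t << ", "
                    << t + dur << "]\n";
          t += dur;
          if (dur == 0) {
            // Label looping: stay on this frame, but cap emissions per frame.
            if (++symbols_on_frame >= max_symbols_per_step) {
              t++;
              symbols_on_frame = 0;
            }
          } else {
            symbols_on_frame = 0;
          }
        }
      }
      return 0;
    }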
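The whitespace-boundary rule the word pass uses (and that PATCH 3 below reuses for subwords) can be seen in isolation here. This is a stripped-down sketch, not code from the patch; the pieces are invented stand-ins for tokenizer->decode() output, where a leading space marks a new word and bare punctuation glues onto the current one:

    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // Invented decoded pieces; a leading space starts a new word.
      const std::vector<std::string> pieces = {"Hel", "lo", " wor", "ld", "."};
      std::vector<std::string> words;
      for (const auto& p : pieces) {
        const bool boundary = !p.empty() && p.front() == ' ';
        const std::string body = boundary ? p.substr(1) : p;
        if (words.empty() || boundary) {
          words.push_back(body);
        } else {
          words.back() += body; // continuation piece or trailing punctuation
        }
      }
      for (const auto& w : words) {
        std::cout << w << "\n"; // prints "Hello" then "world."
      }
      return 0;
    }

The real implementation additionally pins punctuation offsets and tracks per-word frame ranges; only the grouping rule is shown here.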
From 9d072c555c9f1963fb7b3b7872ce0d37940cffc9 Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Thu, 8 Jan 2026 16:17:53 -0500
Subject: [PATCH 4/9] Cleanup comments

---
 examples/models/parakeet/main.cpp | 131 ++++--------------------------
 1 file changed, 17 insertions(+), 114 deletions(-)

diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index f094386f184..02940a68584 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -47,34 +47,15 @@ using ::executorch::runtime::EValue;
 
 namespace {
 
-// TDT duration values
+// TDT duration values (hardcoded for simplicity, comes from model config in NeMo implementation)
+// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
 const std::vector<int64_t> DURATIONS = {0, 1, 2, 3, 4};
-// NeMo: TDT maps a duration-class argmax -> "skip" (advance) in encoder frames.
-// - Viable duration choices come from loss/config into decoding_cfg.durations:
-//   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
-// - Greedy TDT decoding uses: skip = self.durations[d_k]
-//   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2665-L2669
-// Divergence: we hardcode {0,1,2,3,4} here (matches current Parakeet TDT
-// export). If the exported model's duration set changes, this must be updated
-// to match or timestamps/decoding will drift.
 
 struct TokenTimestamp {
   int64_t id;
   int64_t start_offset; // encoder frame index
   int64_t end_offset; // encoder frame index
 };
-// NeMo: TDT timing is represented on the Hypothesis as two parallel lists:
-// - `Hypothesis.timestamp` holds the encoder frame index where each non-blank
-//   token was emitted.
-// - `Hypothesis.token_duration` holds the predicted duration/skip for that token.
-// These are later converted to `{start_offset, end_offset}` via
-// RNNTDecoding._compute_offsets_tdt().
-// - Where NeMo records `timestamp` + `token_duration` during greedy TDT decode:
-//   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2604-L2693
-// - Where NeMo converts them into offsets:
-//   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1128-L1156
-// Divergence: this ExecuTorch example stores `{start_offset,end_offset}` directly
-// as it decodes. We don't preserve NeMo's intermediate `timestep` list.
 
 struct SubwordTimestamp {
   std::string text;
@@ -91,13 +72,6 @@ struct WordTimestamp {
   double start_sec;
   double end_sec;
 };
-// NeMo: word-level timestamps are built from per-token offsets with
-// `get_words_offsets()` and then converted from offsets->seconds by
-// `process_timestamp_outputs()`.
-// - Word grouping: https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
-// - Offsets->seconds: https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L428-L479
-// Note: NeMo's `timestamp['char']` for Parakeet is per-subword token offsets
-// (not true per-character).
 
 struct SegmentTimestamp {
   std::string text;
@@ -106,20 +80,8 @@ struct SegmentTimestamp {
   double start_sec;
   double end_sec;
 };
-// NeMo: segment-level timestamps are built from word offsets with
-// `get_segment_offsets()`.
-// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
-// Divergence: we only segment on terminal punctuation (.,!,?) and do not
-// implement NeMo's optional `segment_gap_threshold` behavior.
 
 bool is_ascii_punctuation_only(const std::string& s) {
-  // NeMo: TDT punctuation timestamp refinement is applied when a punctuation
-  // token appears long after the previous token; NeMo "pins" punctuation to the
-  // previous token's end offset.
-  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
-  // Divergence: NeMo checks membership in a model-specific `supported_punctuation`
-  // set (can include non-ASCII). Here we approximate by checking ASCII
-  // `std::ispunct()` on bytes.
   if (s.empty()) {
     return false;
   }
@@ -132,11 +94,6 @@ bool is_ascii_punctuation_only(const std::string& s) {
 }
 
 size_t ltrim_ascii_whitespace(const std::string& s) {
-  // NeMo: word boundaries for BPE/WPE are detected via tokenizer-type-specific
-  // logic in `get_words_offsets()` (word delimiter char, special markers, etc).
-  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
-  // Divergence: we treat leading *ASCII whitespace* in the decoded piece as the
-  // only word boundary signal.
   size_t i = 0;
   while (i < s.size() && std::isspace(static_cast<unsigned char>(s[i]))) {
     i++;
@@ -148,12 +105,8 @@ std::vector<SubwordTimestamp> tokens_to_subword_timestamps(
     const std::vector<TokenTimestamp>& tokens,
     tokenizers::Tokenizer* tokenizer,
     double seconds_per_encoder_frame) {
-  // NeMo reference: TDT per-token "char" timestamps are computed in
-  // `compute_rnnt_timestamps()` via `_compute_offsets_tdt()` and
-  // `_refine_timestamps_tdt()`:
-  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991
-  // NeMo: "char" timestamps for Parakeet-TDT correspond to per-subword token
-  // offsets, with a TDT punctuation refinement step.
+  // NeMo reference of TDT per-token "char" timestamp computation:
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991
   std::vector<SubwordTimestamp> subwords;
   if (!tokenizer) {
     return subwords;
@@ -201,14 +154,7 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
     tokenizers::Tokenizer* tokenizer,
     double seconds_per_encoder_frame) {
   // NeMo reference for word grouping (subword/char offsets -> word offsets):
-  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
-  //
-  // Divergences from NeMo:
-  // - NeMo builds words from decoded token offsets (and handles tokenizer types);
-  //   here we build words by incrementally decoding each token and using leading
-  //   ASCII whitespace as the boundary.
-  // - NeMo returns `Hypothesis.timestamp['char']` in addition to word/segment;
-  //   this example also emits per-subword timestamps.
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
   std::vector<WordTimestamp> words;
   if (!tokenizer || tokens.empty()) {
     return words;
@@ -249,15 +194,14 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
       continue;
     }
 
+    // TDT sometimes emits punctuation long after preceding token. Thus, pin to previous token.
+    // NeMo applies the same correction:
+    // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
+    // Divergence: NeMo consults `supported_punctuation` from the model; here we
+    // approximate punctuation detection (ASCII-only) via `is_ascii_punctuation_only()`.
     TokenTimestamp adjusted = token_ts;
     const bool is_punct = is_ascii_punctuation_only(trimmed_piece);
     if (is_punct && has_prev_end_offset) {
-      // TDT can sometimes emit punctuation long after the preceding word. Pin
-      // punctuation timing to the previous token end.
-      // NeMo: RNNTDecoding._refine_timestamps_tdt() applies the same correction:
-      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
-      // Divergence: NeMo consults `supported_punctuation` from the model; here we
-      // approximate punctuation detection (ASCII-only) via `is_ascii_punctuation_only()`.
       adjusted.start_offset = prev_end_offset;
       adjusted.end_offset = prev_end_offset;
     }
@@ -266,12 +210,10 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
       current_start_offset = adjusted.start_offset;
       current_end_offset = adjusted.end_offset;
     } else if (had_leading_ws && !is_punct) {
-      // NeMo: `get_words_offsets()` decides when a new word starts using
-      // tokenizer-aware rules (delimiter markers, WPE "##" prefixes, etc):
-      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
-      // Divergence: our boundary rule is strictly "decoded piece had leading
-      // ASCII whitespace and is not punctuation".
+      // NeMo builds words from decoded token offsets w/ tokenizer-aware rules:
+      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
+      // Here we simplify, building words per-token and using leading whitespace as the boundary.
       emit_word();
       current_word = trimmed_piece;
       current_start_offset = adjusted.start_offset;
@@ -295,10 +238,7 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
     const std::vector<WordTimestamp>& words,
     double seconds_per_encoder_frame) {
   // NeMo reference for segment grouping (word offsets -> segment offsets):
-  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
-  // Divergence: we only segment on terminal punctuation (.,!,?) and do not
-  // implement NeMo's optional `segment_gap_threshold` splitting.
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
   std::vector<SegmentTimestamp> segments;
   if (words.empty()) {
     return segments;
@@ -336,9 +276,9 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
     if (!word.text.empty()) {
       char last = word.text.back();
+      // NeMo Divergence: we only segment on terminal punctuation (.,!,?) rather than configurable
+      // segment_delimiter_tokens. Also no `segment_gap_threshold` splitting.
       if (last == '.' || last == '!' || last == '?') {
-        // NeMo: segment delimiters are configurable (default includes '.', '?', '!'):
-        // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L287-L296
         emit_segment();
       }
     }
@@ -359,19 +300,6 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
     int64_t num_rnn_layers = 2,
     int64_t pred_hidden = 640,
     int64_t max_symbols_per_step = 10) {
-  // NeMo reference for greedy TDT decoding (where the *token timing* originates):
-  // - Core greedy TDT loop (token argmax + duration argmax + skip advance):
-  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2627-L2717
-  // - NeMo records per-token `timestamp` + `token_duration` here:
-  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2684-L2693
-  //
-  // Divergences from NeMo:
-  // - We implement a single-item (B=1) decoder loop directly over ExecuTorch
-  //   exported methods (`joint_project_*`, `decoder_predict`, `joint`).
-  // - We take argmax directly on raw logits (no log_softmax); this matches
-  //   NeMo's argmax choice but we do not compute scores/confidence.
-  // - NeMo's loop structure uses an explicit inner loop for `skip==0` label
-  //   looping; here we emulate it with `dur==0` and `symbols_on_frame`.
   std::vector<TokenTimestamp> hypothesis;
   int64_t num_token_classes = vocab_size + 1;
 
@@ -521,20 +449,9 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
     int64_t dur = DURATIONS[dur_idx];
 
     if (k == blank_id) {
-      // NeMo: if blank is emitted with duration=0, it forces progress to avoid
-      // infinite loops (skip==0 -> skip=1):
-      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2700-L2704
-      // Divergence: NeMo advances `time_idx += skip` first and patches `skip`
-      // after the inner loop; here we apply `max(dur,1)` immediately in the
-      // blank branch.
       t += std::max(dur, (int64_t)1);
       symbols_on_frame = 0;
     } else {
-      // NeMo: emits token at `time_idx` and stores duration separately:
-      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2684-L2693
-      // NeMo later converts (timestamp, token_duration) -> (start_offset, end_offset):
-      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1128-L1156
-      // Divergence: we store (start_offset=t, end_offset=t+dur) directly.
       hypothesis.push_back(TokenTimestamp{k, t, t + dur});
 
       // Update decoder state
@@ -584,12 +501,6 @@ std::vector<TokenTimestamp> greedy_decode_executorch(
     t += dur;
 
     if (dur == 0) {
-      // NeMo: label looping occurs when `skip == 0` (stay on same encoder frame)
-      // until a non-zero skip is predicted, capped by `max_symbols_per_step`:
-      // - need_loop = (skip == 0):
-      //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2695-L2699
-      // - force progress after max symbols:
-      //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py#L2715-L2716
       symbols_on_frame++;
       if (symbols_on_frame >= max_symbols_per_step) {
         t++;
@@ -776,15 +687,6 @@ int main(int argc, char** argv) {
   std::cout << "Transcription tokens: " << text << std::endl;
 
   if (FLAGS_timestamps) {
-    // NeMo: offset->seconds conversion uses
-    //   start = start_offset * window_stride * subsampling_factor
-    //   end = end_offset * window_stride * subsampling_factor
-    // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L428-L479
-    //
-    // Divergence: NeMo reads `window_stride` from the preprocessor config and
-    // `subsampling_factor` from the encoder module. In ExecuTorch we require
-    // these values to be exported as `constant_methods` (`window_stride` and
-    // `encoder_subsampling_factor`). If unavailable, we print raw offsets.
     std::vector<::executorch::runtime::EValue> empty_inputs;
     auto window_stride_result = model->execute("window_stride", empty_inputs);
     auto subsampling_factor_result =
@@ -806,4 +708,5 @@ int main(int argc, char** argv) {
       ET_LOG(
           Error,
           "Timestamps requested but model metadata is missing. Re-export the model with constant_methods for window_stride and encoder_subsampling_factor.");
+      return 1;
     }
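One behavior the surviving comments describe is the TDT punctuation "pinning": a punctuation token predicted long after the previous word is snapped back to that word's end offset. A tiny standalone demo of just that offset rewrite, on invented numbers (a period predicted at frame 40, after a word ending at frame 12):

    #include <cctype>
    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    struct Tok {
      std::string piece;
      int64_t start;
      int64_t end;
    };

    // Same ASCII-only approximation the example uses.
    bool punct_only(const std::string& s) {
      if (s.empty()) return false;
      for (unsigned char c : s) {
        if (!std::ispunct(c)) return false;
      }
      return true;
    }

    int main() {
      // Invented offsets: the "." arrives 28 frames after the word ends.
      std::vector<Tok> toks = {{"word", 10, 12}, {".", 40, 40}};
      int64_t prev_end = -1;
      for (auto& t : toks) {
        if (punct_only(t.piece) && prev_end >= 0) {
          t.start = prev_end; // pin punctuation to the previous token's end
          t.end = prev_end;
        }
        prev_end = t.end;
        std::cout << t.piece << " [" << t.start << ", " << t.end << "]\n";
      }
      return 0; // prints: word [10, 12] / . [12, 12]
    }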
From 4c5402b3fb217f84f26e246a0f87f25c52f88eb9 Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Thu, 8 Jan 2026 16:30:40 -0500
Subject: [PATCH 5/9] format

---
 .../models/parakeet/export_parakeet_tdt.py |  2 +-
 examples/models/parakeet/main.cpp          | 54 ++++++++++++-------
 2 files changed, 36 insertions(+), 20 deletions(-)

diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py
index db911279252..3990e1165d2 100644
--- a/examples/models/parakeet/export_parakeet_tdt.py
+++ b/examples/models/parakeet/export_parakeet_tdt.py
@@ -352,7 +352,7 @@ def export_all(model):
 
     sample_rate = model.preprocessor._cfg.sample_rate
     window_stride = float(model.preprocessor._cfg.window_stride)
-    encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 1))
+    encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 8))
     metadata = {
         "num_rnn_layers": num_layers,
         "pred_hidden": pred_hidden,

diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index 02940a68584..225f7588927 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -47,7 +47,7 @@ using ::executorch::runtime::EValue;
 
 namespace {
 
-// TDT duration values (hardcoded for simplicity, comes from model config in NeMo implementation)
+// TDT duration values (comes from model config in the NeMo implementation)
 // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
 const std::vector<int64_t> DURATIONS = {0, 1, 2, 3, 4};
 
@@ -120,7 +120,8 @@ std::vector<SubwordTimestamp> tokens_to_subword_timestamps(
     uint64_t token = static_cast<uint64_t>(token_ts.id);
     auto decode_result = tokenizer->decode(bos_token, token);
 
-    std::string piece = decode_result.ok() ? decode_result.get() : std::string();
+    std::string piece =
+        decode_result.ok() ? decode_result.get() : std::string();
 
     TokenTimestamp adjusted = token_ts;
     size_t non_ws = ltrim_ascii_whitespace(piece);
     std::string trimmed_piece = piece.substr(non_ws);
@@ -139,8 +140,13 @@ std::vector<SubwordTimestamp> tokens_to_subword_timestamps(
       end_sec = seconds_per_encoder_frame * adjusted.end_offset;
     }
 
-    subwords.push_back(SubwordTimestamp{
-        piece, adjusted.start_offset, adjusted.end_offset, start_sec, end_sec});
+    subwords.push_back(
+        SubwordTimestamp{
+            piece,
+            adjusted.start_offset,
+            adjusted.end_offset,
+            start_sec,
+            end_sec});
 
     prev_end_offset = adjusted.end_offset;
     has_prev_end_offset = true;
@@ -178,8 +184,13 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
       start_sec = seconds_per_encoder_frame * current_start_offset;
       end_sec = seconds_per_encoder_frame * current_end_offset;
     }
-    words.push_back(WordTimestamp{
-        current_word, current_start_offset, current_end_offset, start_sec, end_sec});
+    words.push_back(
+        WordTimestamp{
+            current_word,
+            current_start_offset,
+            current_end_offset,
+            start_sec,
+            end_sec});
     current_word.clear();
   };
 
@@ -188,7 +199,8 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
     auto decode_result = tokenizer->decode(prev_token, token);
     prev_token = token;
 
-    std::string piece = decode_result.ok() ? decode_result.get() : std::string();
+    std::string piece =
+        decode_result.ok() ? decode_result.get() : std::string();
     size_t non_ws = ltrim_ascii_whitespace(piece);
     bool had_leading_ws = non_ws > 0;
     std::string trimmed_piece = piece.substr(non_ws);
@@ -197,11 +209,11 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
       continue;
     }
 
-    // TDT sometimes emits punctuation long after preceding token. Thus, pin to previous token.
-    // NeMo applies the same correction:
+    // TDT sometimes emits punctuation long after preceding token. Thus, pin to
+    // previous token. NeMo applies the same correction:
     // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
     // Divergence: NeMo consults `supported_punctuation` from the model; here we
-    // approximate punctuation detection (ASCII-only) via `is_ascii_punctuation_only()`.
+    // approximate with `is_ascii_punctuation_only()`.
     TokenTimestamp adjusted = token_ts;
     const bool is_punct = is_ascii_punctuation_only(trimmed_piece);
     if (is_punct && has_prev_end_offset) {
       adjusted.start_offset = prev_end_offset;
       adjusted.end_offset = prev_end_offset;
     }
@@ -216,7 +228,7 @@ std::vector<WordTimestamp> tokens_to_word_timestamps(
     } else if (had_leading_ws && !is_punct) {
       // NeMo builds words from decoded token offsets w/ tokenizer-aware rules:
       // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
-      // Here we simplify, building words per-token and using leading whitespace as the boundary.
+      // Here, just build words per-token and separate by leading whitespace
       emit_word();
       current_word = trimmed_piece;
       current_start_offset = adjusted.start_offset;
@@ -259,8 +271,13 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
       start_sec = seconds_per_encoder_frame * segment_start_offset;
       end_sec = seconds_per_encoder_frame * segment_end_offset;
     }
-    segments.push_back(SegmentTimestamp{
-        current_segment, segment_start_offset, segment_end_offset, start_sec, end_sec});
+    segments.push_back(
+        SegmentTimestamp{
+            current_segment,
+            segment_start_offset,
+            segment_end_offset,
+            start_sec,
+            end_sec});
     current_segment.clear();
     has_segment = false;
   };
@@ -293,9 +310,10 @@ std::vector<SegmentTimestamp> words_to_segment_timestamps(
 
     if (!word.text.empty()) {
       char last = word.text.back();
-      // NeMo Divergence: we only segment on terminal punctuation (.,!,?) rather than configurable
-      // segment_delimiter_tokens. Also no `segment_gap_threshold` splitting.
+      // NeMo Divergence: we only segment on terminal punctuation (.,!,?) rather
+      // than configurable segment_delimiter_tokens. Also no
+      // `segment_gap_threshold` splitting.
       if (last == '.' || last == '!' || last == '?') {
         emit_segment();
       }
     }
@@ -725,16 +743,14 @@ int main(int argc, char** argv) {
       return 1;
     }
 
+    auto subwords = tokens_to_subword_timestamps(
+        tokens, tokenizer.get(), seconds_per_encoder_frame);
     auto words = tokens_to_word_timestamps(
         tokens, tokenizer.get(), seconds_per_encoder_frame);
     auto segments =
         words_to_segment_timestamps(words, seconds_per_encoder_frame);
 
     std::cout << std::fixed << std::setprecision(2);
-
-    auto subwords = tokens_to_subword_timestamps(
-        tokens, tokenizer.get(), seconds_per_encoder_frame);
-
     std::cout << "\nSubword timestamps:\n";
     for (const auto& sw : subwords) {
       if (seconds_per_encoder_frame > 0.0) {
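The segmenting rule that the reflowed comment above describes ('.', '!', '?' close a segment) reduces to the following standalone sketch. The words are invented stand-ins for the timestamped ones; offset bookkeeping is omitted to show just the grouping:

    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // Invented word stream; segments close on terminal punctuation.
      const std::vector<std::string> words = {"hello", "world.", "how", "are", "you?"};
      std::vector<std::string> segments;
      std::string cur;
      for (const auto& w : words) {
        cur += (cur.empty() ? "" : " ") + w;
        const char last = w.empty() ? '\0' : w.back();
        if (last == '.' || last == '!' || last == '?') {
          segments.push_back(cur);
          cur.clear();
        }
      }
      if (!cur.empty()) segments.push_back(cur); // flush a trailing open segment
      for (const auto& s : segments) {
        std::cout << s << "\n"; // "hello world." then "how are you?"
      }
      return 0;
    }

NeMo additionally supports configurable delimiter tokens and a gap-based split; this example intentionally keeps only the punctuation rule.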
     || last == '?') {
         emit_segment();
       }
     }

From 64e61a8beef2985c35ce517243c7927b079dd5be Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Thu, 8 Jan 2026 17:18:22 -0500
Subject: [PATCH 6/9] Choose type of timestamps you want

---
 examples/models/parakeet/main.cpp | 203 ++++++++++++++++++++----------
 1 file changed, 140 insertions(+), 63 deletions(-)

diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index 225f7588927..f655448eee3 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -35,10 +35,10 @@ DEFINE_string(
     data_path,
     "",
     "Path to data file (.ptd) for delegate data (optional, required for CUDA).");
-DEFINE_bool(
+DEFINE_string(
     timestamps,
-    false,
-    "Output subword-, word-, and segment-level timestamps (requires model metadata).");
+    "none",
+    "Timestamp output mode: none|subword|word|segment|all");
 
 using ::executorch::extension::from_blob;
 using ::executorch::extension::Module;
@@ -51,13 +51,13 @@ namespace {
 // TDT duration values (comes from model config in the NeMo implementation)
 // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
 const std::vector<int64_t> DURATIONS = {0, 1, 2, 3, 4};
 
-struct TokenTimestamp {
+struct Token {
   int64_t id;
   int64_t start_offset; // encoder frame index
   int64_t end_offset; // encoder frame index
 };
 
-struct SubwordTimestamp {
+struct TimestampedSubword {
   std::string text;
   int64_t start_offset;
   int64_t end_offset;
@@ -65,7 +65,7 @@ struct SubwordTimestamp {
   double end_sec;
 };
 
-struct WordTimestamp {
+struct TimestampedWord {
   std::string text;
   int64_t start_offset;
   int64_t end_offset;
@@ -73,7 +73,7 @@ struct WordTimestamp {
   double end_sec;
 };
 
-struct SegmentTimestamp {
+struct TimestampedSegment {
   std::string text;
   int64_t start_offset;
   int64_t end_offset;
@@ -101,13 +101,83 @@ size_t ltrim_ascii_whitespace(const std::string& s) {
   return i;
 }
 
+struct TimestampOutputMode {
+  bool subword = false;
+  bool word = false;
+  bool segment = false;
+
+  bool enabled() const {
+    return subword || word || segment;
+  }
+};
+
+std::string to_lower_ascii(std::string s) {
+  for (char& ch : s) {
+    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
+  }
+  return s;
+}
+
+bool parse_timestamp_output_mode(
+    const std::string& raw,
+    TimestampOutputMode* out,
+    std::string* error) {
+  if (!out) {
+    if (error) {
+      *error = "Internal error: TimestampOutputMode output was null.";
+    }
+    return false;
+  }
+
+  const std::string mode = to_lower_ascii(raw);
+  if (mode == "none") {
+    *out = TimestampOutputMode{};
+    return true;
+  }
+  if (mode == "subword") {
+    out->subword = true;
+    out->word = false;
+    out->segment = false;
+    return true;
+  }
+  if (mode == "word") {
+    out->subword = false;
+    out->word = true;
+    out->segment = false;
+    return true;
+  }
+  if (mode == "segment") {
+    out->subword = false;
+    out->word = false;
+    out->segment = true;
+    return true;
+  }
+  if (mode == "all") {
"all") { + out->subword = true; + out->word = true; + out->segment = true; + return true; + } + + if (error) { + if (raw.empty()) { + *error = + "Invalid --timestamps value (empty). Expected: none, subword, word, segment, all."; + } else { + *error = "Invalid --timestamps value '" + raw + + "'. Expected: none, subword, word, segment, all."; + } + } + return false; +} + +std::vector tokens_to_timestamped_subwords( + const std::vector& tokens, tokenizers::Tokenizer* tokenizer, double seconds_per_encoder_frame) { // NeMo reference of TDT per-token "char" timestamp computation: // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991 - std::vector subwords; + std::vector subwords; if (!tokenizer) { return subwords; } @@ -116,14 +186,14 @@ std::vector tokens_to_subword_timestamps( int64_t prev_end_offset = 0; bool has_prev_end_offset = false; - for (const auto& token_ts : tokens) { - uint64_t token = static_cast(token_ts.id); - auto decode_result = tokenizer->decode(bos_token, token); + for (const auto& token : tokens) { + uint64_t token_id = static_cast(token.id); + auto decode_result = tokenizer->decode(bos_token, token_id); std::string piece = decode_result.ok() ? decode_result.get() : std::string(); - TokenTimestamp adjusted = token_ts; + Token adjusted = token; size_t non_ws = ltrim_ascii_whitespace(piece); std::string trimmed_piece = piece.substr(non_ws); @@ -141,7 +211,7 @@ std::vector tokens_to_subword_timestamps( } subwords.push_back( - SubwordTimestamp{ + TimestampedSubword{ piece, adjusted.start_offset, adjusted.end_offset, @@ -155,18 +225,18 @@ std::vector tokens_to_subword_timestamps( return subwords; } -std::vector tokens_to_word_timestamps( - const std::vector& tokens, +std::vector tokens_to_timestamped_words( + const std::vector& tokens, tokenizers::Tokenizer* tokenizer, double seconds_per_encoder_frame) { // NeMo reference for word grouping (subword/char offsets -> word offsets): // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224 - std::vector words; + std::vector words; if (!tokenizer || tokens.empty()) { return words; } - uint64_t prev_token = 0; + uint64_t prev_token_id = 0; std::string current_word; int64_t current_start_offset = 0; @@ -185,7 +255,7 @@ std::vector tokens_to_word_timestamps( end_sec = seconds_per_encoder_frame * current_end_offset; } words.push_back( - WordTimestamp{ + TimestampedWord{ current_word, current_start_offset, current_end_offset, @@ -194,10 +264,10 @@ std::vector tokens_to_word_timestamps( current_word.clear(); }; - for (const auto& token_ts : tokens) { - uint64_t token = static_cast(token_ts.id); - auto decode_result = tokenizer->decode(prev_token, token); - prev_token = token; + for (const auto& token : tokens) { + uint64_t token_id = static_cast(token.id); + auto decode_result = tokenizer->decode(prev_token_id, token_id); + prev_token_id = token_id; std::string piece = decode_result.ok() ? decode_result.get() : std::string(); @@ -209,12 +279,12 @@ std::vector tokens_to_word_timestamps( continue; } - // TDT sometimes emits punctuation long after preceding token. Thus, pin to - // previous token. NeMo applies the same correction: + // TDT sometimes emits punctuation long after preceding token. Thus, pin + // timestamp to previous token. 
     // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
     // Divergence: NeMo consults `supported_punctuation` from the model; here we
     // approximate with `is_ascii_punctuation_only()`.
-    TokenTimestamp adjusted = token_ts;
+    Token adjusted = token;
     const bool is_punct = is_ascii_punctuation_only(trimmed_piece);
     if (is_punct && has_prev_end_offset) {
       adjusted.start_offset = prev_end_offset;
       adjusted.end_offset = prev_end_offset;
     }
@@ -253,12 +323,12 @@ std::vector<TimestampedWord> tokens_to_timestamped_words(
   return words;
 }
 
-std::vector<SegmentTimestamp> words_to_segment_timestamps(
-    const std::vector<WordTimestamp>& words,
+std::vector<TimestampedSegment> timestamped_words_to_timestamped_segments(
+    const std::vector<TimestampedWord>& words,
     double seconds_per_encoder_frame) {
   // NeMo reference for segment grouping (word offsets -> segment offsets):
   // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
-  std::vector<SegmentTimestamp> segments;
+  std::vector<TimestampedSegment> segments;
   if (words.empty()) {
     return segments;
   }
@@ -276,7 +346,7 @@ std::vector<TimestampedSegment> timestamped_words_to_timestamped_segments(
       end_sec = seconds_per_encoder_frame * segment_end_offset;
     }
     segments.push_back(
-        SegmentTimestamp{
+        TimestampedSegment{
             current_segment,
             segment_start_offset,
             segment_end_offset,
             start_sec,
             end_sec});
@@ -309,7 +379,7 @@ std::vector<TimestampedSegment> timestamped_words_to_timestamped_segments(
   return segments;
 }
 
-std::vector<TokenTimestamp> greedy_decode_executorch(
+std::vector<Token> greedy_decode_executorch(
     Module& model,
     const ::executorch::aten::Tensor& encoder_output,
     int64_t encoder_len,
@@ -318,7 +388,7 @@ std::vector<Token> greedy_decode_executorch(
     int64_t num_rnn_layers = 2,
     int64_t pred_hidden = 640,
     int64_t max_symbols_per_step = 10) {
-  std::vector<TokenTimestamp> hypothesis;
+  std::vector<Token> hypothesis;
   int64_t num_token_classes = vocab_size + 1;
 
   // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim]
@@ -473,7 +543,7 @@ std::vector<Token> greedy_decode_executorch(
       t += std::max(dur, (int64_t)1);
       symbols_on_frame = 0;
     } else {
-      hypothesis.push_back(TokenTimestamp{k, t, t + dur});
+      hypothesis.push_back(Token{k, t, t + dur});
 
       // Update decoder state
       std::vector<int64_t> token_data = {k};
@@ -556,6 +626,14 @@ std::string tokens_to_text(
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
+  TimestampOutputMode timestamp_mode;
+  std::string timestamp_error;
+  if (!parse_timestamp_output_mode(
+          FLAGS_timestamps, &timestamp_mode, &timestamp_error)) {
+    ET_LOG(Error, "%s", timestamp_error.c_str());
+    return 1;
+  }
+
   if (FLAGS_audio_path.empty()) {
     ET_LOG(Error, "audio_path flag must be provided.");
     return 1;
   }
@@ -702,9 +780,9 @@ int main(int argc, char** argv) {
     token_ids.push_back(t.id);
   }
   std::string text = tokens_to_text(token_ids, tokenizer.get());
-  std::cout << "Transcription tokens: " << text << std::endl;
+  std::cout << "Transcribed text: " << text << std::endl;
 
-  if (FLAGS_timestamps) {
+  if (timestamp_mode.enabled()) {
     std::vector<::executorch::runtime::EValue> empty_inputs;
     auto window_stride_result = model->execute("window_stride", empty_inputs);
     auto subsampling_factor_result =
@@ -725,48 +803,64 @@ int main(int argc, char** argv) {
     } else {
       ET_LOG(
           Error,
-          "Timestamps requested but model metadata is missing. Re-export the model with constant_methods for window_stride and encoder_subsampling_factor.");
+          "Timestamps requested (--timestamps=%s) but model metadata is missing. Re-export the model with constant_methods for window_stride and encoder_subsampling_factor.",
+          FLAGS_timestamps.c_str());
       return 1;
     }
 
-    auto subwords = tokens_to_subword_timestamps(
-        tokens, tokenizer.get(), seconds_per_encoder_frame);
-    auto words = tokens_to_word_timestamps(
-        tokens, tokenizer.get(), seconds_per_encoder_frame);
-    auto segments =
-        words_to_segment_timestamps(words, seconds_per_encoder_frame);
-
     std::cout << std::fixed << std::setprecision(2);
-    std::cout << "\nSubword timestamps:\n";
-    for (const auto& sw : subwords) {
-      if (seconds_per_encoder_frame > 0.0) {
-        std::cout << "[" << sw.start_sec << ", " << sw.end_sec << "] ";
-      } else {
-        std::cout << "[" << sw.start_offset << ", " << sw.end_offset << "] ";
-      }
-      std::cout << sw.text << "\n";
-    }
 
-    std::cout << "\nWord timestamps:\n";
-    for (const auto& w : words) {
-      if (seconds_per_encoder_frame > 0.0) {
-        std::cout << "[" << w.start_sec << ", " << w.end_sec << "] ";
-      } else {
-        std::cout << "[" << w.start_offset << ", " << w.end_offset << "] ";
-      }
-      std::cout << w.text << "\n";
-    }
-
-    std::cout << "\nSegment timestamps:\n";
-    for (const auto& s : segments) {
-      if (seconds_per_encoder_frame > 0.0) {
-        std::cout << "[" << s.start_sec << ", " << s.end_sec << "] ";
-      } else {
-        std::cout << "[" << s.start_offset << ", " << s.end_offset << "] ";
-      }
-      std::cout << s.text << "\n";
-    }
+    if (timestamp_mode.subword) {
+      std::vector<TimestampedSubword> subwords = tokens_to_timestamped_subwords(
+          tokens, tokenizer.get(), seconds_per_encoder_frame);
+      std::cout << "\nSubword timestamps:\n";
+      for (const auto& sw : subwords) {
+        std::cout << "[" << sw.start_sec << ", " << sw.end_sec << "] ";
+        std::cout << sw.text << "\n";
+      }
+    }
+
+    std::vector<TimestampedWord> words;
+    if (timestamp_mode.word || timestamp_mode.segment) {
+      words = tokens_to_timestamped_words(
+          tokens, tokenizer.get(), seconds_per_encoder_frame);
+    }
+    std::vector<TimestampedSegment> segments;
+    if (timestamp_mode.segment) {
+      segments = timestamped_words_to_timestamped_segments(
+          words, seconds_per_encoder_frame);
+    }
+    if (timestamp_mode.word) {
+      std::cout << "\nWord timestamps:\n";
+      for (const auto& w : words) {
+        std::cout << "[" << w.start_sec << ", " << w.end_sec << "] ";
+        std::cout << w.text << "\n";
+      }
+    }
+    if (timestamp_mode.segment) {
+      std::cout << "\nSegment timestamps:\n";
+      for (const auto& s : segments) {
+        std::cout << "[" << s.start_sec << ", " << s.end_sec << "] ";
+        std::cout << s.text << "\n";
+      }
+    }
   }
 
   ET_LOG(Info, "Done!");
   return 0;
 }
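The --timestamps handling this commit adds maps one flag string onto up to three output passes. The snippet below is a stripped-down mirror of that mapping for a quick sanity check; it is a simplified re-implementation for illustration, not the code from the patch:

    #include <iostream>
    #include <string>

    struct Mode {
      bool subword = false;
      bool word = false;
      bool segment = false;
    };

    // Same semantics as the patch's parser: none|subword|word|segment|all.
    bool parse(const std::string& s, Mode& m) {
      m = Mode{};
      if (s == "none") return true;
      if (s == "subword") { m.subword = true; return true; }
      if (s == "word") { m.word = true; return true; }
      if (s == "segment") { m.segment = true; return true; }
      if (s == "all") { m = {true, true, true}; return true; }
      return false; // unrecognized value -> caller reports an error
    }

    int main() {
      Mode m;
      for (const std::string s : {"word", "all", "typo"}) {
        std::cout << s << " -> " << (parse(s, m) ? "ok" : "invalid") << "\n";
      }
      return 0;
    }

Note that "segment" alone still has to compute word timestamps internally (segments are built from words), which is why the patch computes `words` whenever either word or segment output is requested.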
From f322904537f3f2fc5b3cd7ac217b35cf362960a4 Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Fri, 9 Jan 2026 13:19:58 -0500
Subject: [PATCH 7/9] cleanups

---
 examples/models/parakeet/.clang-format | 244 ++++++++++++++++++++++
 examples/models/parakeet/README.md     |   2 +-
 examples/models/parakeet/main.cpp      | 275 +++++++++++--------------
 3 files changed, 371 insertions(+), 150 deletions(-)
 create mode 100644 examples/models/parakeet/.clang-format

diff --git a/examples/models/parakeet/.clang-format b/examples/models/parakeet/.clang-format
new file mode 100644
index 00000000000..8ec7b569e24
--- /dev/null
+++ b/examples/models/parakeet/.clang-format
@@ -0,0 +1,244 @@
+---
+Language: Cpp
+AccessModifierOffset: -1
+AlignAfterOpenBracket: AlwaysBreak
+AlignArrayOfStructures: None
+AlignConsecutiveAssignments:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionPointers: false
+  PadOperators: true
+AlignConsecutiveBitFields:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionPointers: false
+  PadOperators: true
+AlignConsecutiveDeclarations:
+  Enabled: false
+  AcrossEmptyLines: false
+  AcrossComments: false
+  AlignCompound: false
+  AlignFunctionPointers: false
+  PadOperators: true
+AlignConsecutiveMacros:
+  Enabled: false
AcrossEmptyLines: false + AcrossComments: false + AlignCompound: false + AlignFunctionPointers: false + PadOperators: true +AlignConsecutiveShortCaseStatements: + Enabled: false + AcrossEmptyLines: false + AcrossComments: false + AlignCaseColons: false +AlignEscapedNewlines: Left +AlignOperands: DontAlign +AlignTrailingComments: + Kind: Never + OverEmptyLines: 0 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +AllowBreakBeforeNoexceptSpecifier: Never +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortCompoundRequirementOnASingleLine: true +AllowShortEnumsOnASingleLine: true +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: All +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: Yes +AttributeMacros: + - __capability +BinPackArguments: false +BinPackParameters: false +BitFieldColonSpacing: Both +BraceWrapping: + AfterCaseLabel: false + AfterClass: false + AfterControlStatement: Never + AfterEnum: false + AfterExternBlock: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + BeforeLambdaBody: false + BeforeWhile: false + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: true + SplitEmptyNamespace: true +BreakAdjacentStringLiterals: true +BreakAfterAttributes: Leave +BreakAfterJavaFieldAnnotations: false +BreakArrays: true +BreakBeforeBinaryOperators: None +BreakBeforeConceptDeclarations: Always +BreakBeforeBraces: Attach +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: true +BreakConstructorInitializers: BeforeColon +BreakInheritanceList: BeforeColon +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +EmptyLineAfterAccessModifier: Never +EmptyLineBeforeAccessModifier: LogicalBlock +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - FOR_EACH + - FOR_EACH_R + - FOR_EACH_RANGE +IfMacros: + - KJ_IF_MAYBE +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + SortPriority: 0 + CaseSensitive: false + - Regex: '^<.*' + Priority: 2 + SortPriority: 0 + CaseSensitive: false + - Regex: '.*' + Priority: 3 + SortPriority: 0 + CaseSensitive: false +IncludeIsMainRegex: '(Test)?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentCaseBlocks: false +IndentCaseLabels: true +IndentExternBlock: AfterExternBlock +IndentGotoLabels: true +IndentPPDirectives: None +IndentRequiresClause: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +InsertBraces: false +InsertNewlineAtEOF: false +InsertTrailingCommas: None +IntegerLiteralSeparator: + Binary: 0 + BinaryMinDigits: 0 + Decimal: 0 + DecimalMinDigits: 0 + Hex: 0 + HexMinDigits: 0 +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +KeepEmptyLinesAtEOF: false +LambdaBodyIndentation: Signature +LineEnding: DeriveLF +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCBreakBeforeNestedBlockParam: true 
+ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PackConstructorInitializers: NextLine +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakOpenParenthesis: 0 +PenaltyBreakScopeResolution: 500 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyIndentedWhitespace: 0 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +PPIndentWidth: -1 +QualifierAlignment: Leave +ReferenceAlignment: Pointer +ReflowComments: true +RemoveBracesLLVM: false +RemoveParentheses: Leave +RemoveSemicolon: false +RequiresClausePosition: OwnLine +RequiresExpressionIndentation: OuterScope +SeparateDefinitionBlocks: Leave +ShortNamespaceLines: 1 +SkipMacroDefinitionBody: false +SortIncludes: CaseSensitive +SortJavaStaticImport: Before +SortUsingDeclarations: LexicographicNumeric +SpaceAfterCStyleCast: false +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceAroundPointerQualifiers: Default +SpaceBeforeAssignmentOperators: true +SpaceBeforeCaseColon: false +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeJsonColon: false +SpaceBeforeParens: ControlStatements +SpaceBeforeParensOptions: + AfterControlStatements: true + AfterForeachMacros: true + AfterFunctionDefinitionName: false + AfterFunctionDeclarationName: false + AfterIfMacros: true + AfterOverloadedOperator: false + AfterPlacementOperator: true + AfterRequiresInClause: false + AfterRequiresInExpression: false + BeforeNonEmptyParentheses: false +SpaceBeforeRangeBasedForLoopColon: true +SpaceBeforeSquareBrackets: false +SpaceInEmptyBlock: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: Never +SpacesInContainerLiterals: true +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParens: Never +SpacesInParensOptions: + InCStyleCasts: false + InConditionalStatements: false + InEmptyParentheses: false + Other: false +SpacesInSquareBrackets: false +Standard: Latest +StatementAttributeLikeMacros: + - Q_EMIT +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +TabWidth: 8 +UseTab: Never +VerilogBreakBetweenInstancePorts: true +WhitespaceSensitiveMacros: + - BOOST_PP_STRINGIZE + - CF_SWIFT_NAME + - NS_SWIFT_NAME + - PP_STRINGIZE + - STRINGIZE +... diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 68eab0f27bf..ab4eb8640b3 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -71,4 +71,4 @@ From the executorch root directory: | `--audio_path` | Path to input audio file (.wav) | | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) | | `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) | -| `--timestamps` | Print word/segment timestamps (requires `window_stride` + `encoder_subsampling_factor` in model metadata) | +| `--timestamps` | Timestamp output mode: none\|subword\|word\|segment\|all (requires `window_stride` + `encoder_subsampling_factor` in model metadata) | diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index f655448eee3..f9825fed975 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -6,14 +6,16 @@ * LICENSE file in the root directory of this source tree. 
*/ +#include #include -#include +#include #include -#include #include #include #include #include +#include +#include #include #include @@ -57,7 +59,7 @@ struct Token { int64_t end_offset; // encoder frame index }; -struct TimestampedSubword { +struct TimestampedTextSpan { std::string text; int64_t start_offset; int64_t end_offset; @@ -65,23 +67,7 @@ struct TimestampedSubword { double end_sec; }; -struct TimestampedWord { - std::string text; - int64_t start_offset; - int64_t end_offset; - double start_sec; - double end_sec; -}; - -struct TimestampedSegment { - std::string text; - int64_t start_offset; - int64_t end_offset; - double start_sec; - double end_sec; -}; - -bool is_ascii_punctuation_only(const std::string& s) { +bool is_ascii_punctuation_only(std::string_view s) { if (s.empty()) { return false; } @@ -93,7 +79,7 @@ bool is_ascii_punctuation_only(const std::string& s) { return true; } -size_t ltrim_ascii_whitespace(const std::string& s) { +size_t ltrim_ascii_whitespace(std::string_view s) { size_t i = 0; while (i < s.size() && std::isspace(static_cast(s[i]))) { i++; @@ -101,6 +87,59 @@ size_t ltrim_ascii_whitespace(const std::string& s) { return i; } +struct DecodedPiece { + std::string piece; + size_t trimmed_offset = 0; + bool had_leading_whitespace = false; + + std::string_view trimmed_piece() const { + std::string_view trimmed(piece); + trimmed.remove_prefix(trimmed_offset); + return trimmed; + } +}; + +DecodedPiece decode_piece( + tokenizers::Tokenizer& tokenizer, + uint64_t prev_token_id, + uint64_t token_id) { + auto decode_result = tokenizer.decode(prev_token_id, token_id); + DecodedPiece decoded; + decoded.piece = decode_result.ok() ? decode_result.get() : std::string(); + decoded.trimmed_offset = ltrim_ascii_whitespace(decoded.piece); + decoded.had_leading_whitespace = decoded.trimmed_offset > 0; + return decoded; +} + +TimestampedTextSpan make_timestamped_span( + std::string text, + int64_t start_offset, + int64_t end_offset, + double seconds_per_encoder_frame) { + return TimestampedTextSpan{ + std::move(text), + start_offset, + end_offset, + seconds_per_encoder_frame * start_offset, + seconds_per_encoder_frame * end_offset}; +} + +Token apply_tdt_punctuation_timestamp_correction( + const Token& token, + bool is_punct, + int64_t prev_end_offset, + bool has_prev_end_offset) { + // TDT sometimes emits punctuation long after preceding token. Thus, pin + // timestamp to previous token. NeMo applies the same correction: + // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189 + // Divergence: NeMo consults `supported_punctuation` from the model; here we + // approximate punctuation detection with `is_ascii_punctuation_only()`. 
+ if (!is_punct || !has_prev_end_offset) { + return token; + } + return Token{token.id, prev_end_offset, prev_end_offset}; +} + struct TimestampOutputMode { bool subword = false; bool word = false; @@ -171,52 +210,31 @@ bool parse_timestamp_output_mode( return false; } -std::vector tokens_to_timestamped_subwords( +std::vector tokens_to_timestamped_subwords( const std::vector& tokens, - tokenizers::Tokenizer* tokenizer, + tokenizers::Tokenizer& tokenizer, double seconds_per_encoder_frame) { // NeMo reference of TDT per-token "char" timestamp computation: // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991 - std::vector subwords; - if (!tokenizer) { - return subwords; - } + std::vector subwords; + subwords.reserve(tokens.size()); - const uint64_t bos_token = tokenizer->bos_tok(); + const uint64_t bos_token = tokenizer.bos_tok(); int64_t prev_end_offset = 0; bool has_prev_end_offset = false; for (const auto& token : tokens) { uint64_t token_id = static_cast(token.id); - auto decode_result = tokenizer->decode(bos_token, token_id); + DecodedPiece decoded = decode_piece(tokenizer, bos_token, token_id); + const bool is_punct = is_ascii_punctuation_only(decoded.trimmed_piece()); + const Token adjusted = apply_tdt_punctuation_timestamp_correction( + token, is_punct, prev_end_offset, has_prev_end_offset); - std::string piece = - decode_result.ok() ? decode_result.get() : std::string(); - - Token adjusted = token; - size_t non_ws = ltrim_ascii_whitespace(piece); - std::string trimmed_piece = piece.substr(non_ws); - - const bool is_punct = is_ascii_punctuation_only(trimmed_piece); - if (is_punct && has_prev_end_offset) { - adjusted.start_offset = prev_end_offset; - adjusted.end_offset = prev_end_offset; - } - - double start_sec = -1.0; - double end_sec = -1.0; - if (seconds_per_encoder_frame > 0.0) { - start_sec = seconds_per_encoder_frame * adjusted.start_offset; - end_sec = seconds_per_encoder_frame * adjusted.end_offset; - } - - subwords.push_back( - TimestampedSubword{ - piece, - adjusted.start_offset, - adjusted.end_offset, - start_sec, - end_sec}); + subwords.push_back(make_timestamped_span( + std::move(decoded.piece), + adjusted.start_offset, + adjusted.end_offset, + seconds_per_encoder_frame)); prev_end_offset = adjusted.end_offset; has_prev_end_offset = true; @@ -225,16 +243,13 @@ std::vector tokens_to_timestamped_subwords( return subwords; } -std::vector tokens_to_timestamped_words( +std::vector tokens_to_timestamped_words( const std::vector& tokens, - tokenizers::Tokenizer* tokenizer, + tokenizers::Tokenizer& tokenizer, double seconds_per_encoder_frame) { // NeMo reference for word grouping (subword/char offsets -> word offsets): // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224 - std::vector words; - if (!tokenizer || tokens.empty()) { - return words; - } + std::vector words; uint64_t prev_token_id = 0; @@ -248,63 +263,41 @@ std::vector tokens_to_timestamped_words( if (current_word.empty()) { return; } - double start_sec = -1.0; - double end_sec = -1.0; - if (seconds_per_encoder_frame > 0.0) { - start_sec = seconds_per_encoder_frame * current_start_offset; - end_sec = seconds_per_encoder_frame * current_end_offset; - } - words.push_back( - TimestampedWord{ - current_word, - current_start_offset, - current_end_offset, - start_sec, - end_sec}); + words.push_back(make_timestamped_span( + std::move(current_word), + current_start_offset, + current_end_offset, + 
seconds_per_encoder_frame)); current_word.clear(); }; for (const auto& token : tokens) { uint64_t token_id = static_cast(token.id); - auto decode_result = tokenizer->decode(prev_token_id, token_id); + DecodedPiece decoded = decode_piece(tokenizer, prev_token_id, token_id); prev_token_id = token_id; - - std::string piece = - decode_result.ok() ? decode_result.get() : std::string(); - size_t non_ws = ltrim_ascii_whitespace(piece); - bool had_leading_ws = non_ws > 0; - std::string trimmed_piece = piece.substr(non_ws); - + const std::string_view trimmed_piece = decoded.trimmed_piece(); if (trimmed_piece.empty()) { continue; } - // TDT sometimes emits punctuation long after preceding token. Thus, pin - // timestamp to previous token. NeMo applies the same correction: - // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189 - // Divergence: NeMo consults `supported_punctuation` from the model; here we - // approximate with `is_ascii_punctuation_only()`. - Token adjusted = token; const bool is_punct = is_ascii_punctuation_only(trimmed_piece); - if (is_punct && has_prev_end_offset) { - adjusted.start_offset = prev_end_offset; - adjusted.end_offset = prev_end_offset; - } + const Token adjusted = apply_tdt_punctuation_timestamp_correction( + token, is_punct, prev_end_offset, has_prev_end_offset); if (current_word.empty()) { - current_word = trimmed_piece; + current_word.assign(trimmed_piece.data(), trimmed_piece.size()); current_start_offset = adjusted.start_offset; current_end_offset = adjusted.end_offset; - } else if (had_leading_ws && !is_punct) { + } else if (decoded.had_leading_whitespace && !is_punct) { // NeMo builds words from decoded token offsets w/ tokenizer-aware rules: // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99 // Here, just build words per-token and separate by leading whitespace emit_word(); - current_word = trimmed_piece; + current_word.assign(trimmed_piece.data(), trimmed_piece.size()); current_start_offset = adjusted.start_offset; current_end_offset = adjusted.end_offset; } else { - current_word += trimmed_piece; + current_word.append(trimmed_piece.data(), trimmed_piece.size()); current_end_offset = adjusted.end_offset; } @@ -316,12 +309,12 @@ std::vector tokens_to_timestamped_words( return words; } -std::vector timestamped_words_to_timestamped_segments( - const std::vector& words, +std::vector timestamped_words_to_timestamped_segments( + const std::vector& words, double seconds_per_encoder_frame) { // NeMo reference for segment grouping (word offsets -> segment offsets): // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327 - std::vector segments; + std::vector segments; if (words.empty()) { return segments; } @@ -335,19 +328,11 @@ std::vector timestamped_words_to_timestamped_segments( if (!has_segment || current_segment.empty()) { return; } - double start_sec = -1.0; - double end_sec = -1.0; - if (seconds_per_encoder_frame > 0.0) { - start_sec = seconds_per_encoder_frame * segment_start_offset; - end_sec = seconds_per_encoder_frame * segment_end_offset; - } - segments.push_back( - TimestampedSegment{ - current_segment, - segment_start_offset, - segment_end_offset, - start_sec, - end_sec}); + segments.push_back(make_timestamped_span( + std::move(current_segment), + segment_start_offset, + segment_end_offset, + seconds_per_encoder_frame)); current_segment.clear(); has_segment = false; }; @@ 
-604,23 +589,33 @@ std::vector greedy_decode_executorch( } std::string tokens_to_text( - const std::vector& tokens, - tokenizers::Tokenizer* tokenizer) { + const std::vector& tokens, + tokenizers::Tokenizer& tokenizer) { // Decode tokens to text one by one std::string result; uint64_t prev_token = 0; - for (size_t i = 0; i < tokens.size(); i++) { - uint64_t token = static_cast(tokens[i]); - auto decode_result = tokenizer->decode(prev_token, token); + for (const auto& token : tokens) { + uint64_t token_id = static_cast(token.id); + auto decode_result = tokenizer.decode(prev_token, token_id); if (decode_result.ok()) { result += decode_result.get(); } - prev_token = token; + prev_token = token_id; } return result; } +void print_timestamped_spans( + const char* label, + const std::vector& spans) { + std::cout << "\n" << label << " timestamps:\n"; + for (const auto& span : spans) { + std::cout << "[" << span.start_sec << ", " << span.end_sec << "] "; + std::cout << span.text << "\n"; + } +} + } // namespace int main(int argc, char** argv) { @@ -774,12 +769,7 @@ int main(int argc, char** argv) { } // Convert tokens to text - std::vector token_ids; - token_ids.reserve(tokens.size()); - for (const auto& t : tokens) { - token_ids.push_back(t.id); - } - std::string text = tokens_to_text(token_ids, tokenizer.get()); + std::string text = tokens_to_text(tokens, *tokenizer); std::cout << "Transcribed text: " << text << std::endl; if (timestamp_mode.enabled()) { @@ -810,38 +800,25 @@ int main(int argc, char** argv) { std::cout << std::fixed << std::setprecision(2); if (timestamp_mode.subword) { - std::vector subwords = tokens_to_timestamped_subwords( - tokens, tokenizer.get(), seconds_per_encoder_frame); - std::cout << "\nSubword timestamps:\n"; - for (const auto& sw : subwords) { - std::cout << "[" << sw.start_sec << ", " << sw.end_sec << "] "; - std::cout << sw.text << "\n"; - } + print_timestamped_spans( + "Subword", + tokens_to_timestamped_subwords( + tokens, *tokenizer, seconds_per_encoder_frame)); } - std::vector words; + std::vector words; if (timestamp_mode.word || timestamp_mode.segment) { words = tokens_to_timestamped_words( - tokens, tokenizer.get(), seconds_per_encoder_frame); - } - std::vector segments; - if (timestamp_mode.segment) { - segments = timestamped_words_to_timestamped_segments( - words, seconds_per_encoder_frame); + tokens, *tokenizer, seconds_per_encoder_frame); } if (timestamp_mode.word) { - std::cout << "\nWord timestamps:\n"; - for (const auto& w : words) { - std::cout << "[" << w.start_sec << ", " << w.end_sec << "] "; - std::cout << w.text << "\n"; - } + print_timestamped_spans("Word", words); } if (timestamp_mode.segment) { - std::cout << "\nSegment timestamps:\n"; - for (const auto& s : segments) { - std::cout << "[" << s.start_sec << ", " << s.end_sec << "] "; - std::cout << s.text << "\n"; - } + print_timestamped_spans( + "Segment", + timestamped_words_to_timestamped_segments( + words, seconds_per_encoder_frame)); } } From 9d31012c2d4f11b62e3eb7d209321baa6fda42cf Mon Sep 17 00:00:00 2001 From: Matt Clayton Date: Fri, 9 Jan 2026 13:20:30 -0500 Subject: [PATCH 8/9] no clang format --- .clang-format | 244 -------------------------------------------------- 1 file changed, 244 deletions(-) delete mode 100644 .clang-format diff --git a/.clang-format b/.clang-format deleted file mode 100644 index 8ec7b569e24..00000000000 --- a/.clang-format +++ /dev/null @@ -1,244 +0,0 @@ ---- -Language: Cpp -AccessModifierOffset: -1 -AlignAfterOpenBracket: AlwaysBreak -AlignArrayOfStructures: 
None -AlignConsecutiveAssignments: - Enabled: false - AcrossEmptyLines: false - AcrossComments: false - AlignCompound: false - AlignFunctionPointers: false - PadOperators: true -AlignConsecutiveBitFields: - Enabled: false - AcrossEmptyLines: false - AcrossComments: false - AlignCompound: false - AlignFunctionPointers: false - PadOperators: true -AlignConsecutiveDeclarations: - Enabled: false - AcrossEmptyLines: false - AcrossComments: false - AlignCompound: false - AlignFunctionPointers: false - PadOperators: true -AlignConsecutiveMacros: - Enabled: false - AcrossEmptyLines: false - AcrossComments: false - AlignCompound: false - AlignFunctionPointers: false - PadOperators: true -AlignConsecutiveShortCaseStatements: - Enabled: false - AcrossEmptyLines: false - AcrossComments: false - AlignCaseColons: false -AlignEscapedNewlines: Left -AlignOperands: DontAlign -AlignTrailingComments: - Kind: Never - OverEmptyLines: 0 -AllowAllArgumentsOnNextLine: true -AllowAllParametersOfDeclarationOnNextLine: false -AllowBreakBeforeNoexceptSpecifier: Never -AllowShortBlocksOnASingleLine: Never -AllowShortCaseLabelsOnASingleLine: false -AllowShortCompoundRequirementOnASingleLine: true -AllowShortEnumsOnASingleLine: true -AllowShortFunctionsOnASingleLine: Empty -AllowShortIfStatementsOnASingleLine: Never -AllowShortLambdasOnASingleLine: All -AllowShortLoopsOnASingleLine: false -AlwaysBreakAfterDefinitionReturnType: None -AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: true -AlwaysBreakTemplateDeclarations: Yes -AttributeMacros: - - __capability -BinPackArguments: false -BinPackParameters: false -BitFieldColonSpacing: Both -BraceWrapping: - AfterCaseLabel: false - AfterClass: false - AfterControlStatement: Never - AfterEnum: false - AfterExternBlock: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - BeforeCatch: false - BeforeElse: false - BeforeLambdaBody: false - BeforeWhile: false - IndentBraces: false - SplitEmptyFunction: true - SplitEmptyRecord: true - SplitEmptyNamespace: true -BreakAdjacentStringLiterals: true -BreakAfterAttributes: Leave -BreakAfterJavaFieldAnnotations: false -BreakArrays: true -BreakBeforeBinaryOperators: None -BreakBeforeConceptDeclarations: Always -BreakBeforeBraces: Attach -BreakBeforeInlineASMColon: OnlyMultiline -BreakBeforeTernaryOperators: true -BreakConstructorInitializers: BeforeColon -BreakInheritanceList: BeforeColon -BreakStringLiterals: false -ColumnLimit: 80 -CommentPragmas: '^ IWYU pragma:' -CompactNamespaces: false -ConstructorInitializerIndentWidth: 4 -ContinuationIndentWidth: 4 -Cpp11BracedListStyle: true -DerivePointerAlignment: false -DisableFormat: false -EmptyLineAfterAccessModifier: Never -EmptyLineBeforeAccessModifier: LogicalBlock -ExperimentalAutoDetectBinPacking: false -FixNamespaceComments: true -ForEachMacros: - - FOR_EACH - - FOR_EACH_R - - FOR_EACH_RANGE -IfMacros: - - KJ_IF_MAYBE -IncludeBlocks: Preserve -IncludeCategories: - - Regex: '^<.*\.h(pp)?>' - Priority: 1 - SortPriority: 0 - CaseSensitive: false - - Regex: '^<.*' - Priority: 2 - SortPriority: 0 - CaseSensitive: false - - Regex: '.*' - Priority: 3 - SortPriority: 0 - CaseSensitive: false -IncludeIsMainRegex: '(Test)?$' -IncludeIsMainSourceRegex: '' -IndentAccessModifiers: false -IndentCaseBlocks: false -IndentCaseLabels: true -IndentExternBlock: AfterExternBlock -IndentGotoLabels: true -IndentPPDirectives: None -IndentRequiresClause: true -IndentWidth: 2 -IndentWrappedFunctionNames: false 
-InsertBraces: false -InsertNewlineAtEOF: false -InsertTrailingCommas: None -IntegerLiteralSeparator: - Binary: 0 - BinaryMinDigits: 0 - Decimal: 0 - DecimalMinDigits: 0 - Hex: 0 - HexMinDigits: 0 -JavaScriptQuotes: Leave -JavaScriptWrapImports: true -KeepEmptyLinesAtTheStartOfBlocks: false -KeepEmptyLinesAtEOF: false -LambdaBodyIndentation: Signature -LineEnding: DeriveLF -MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -ObjCBinPackProtocolList: Auto -ObjCBlockIndentWidth: 2 -ObjCBreakBeforeNestedBlockParam: true -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: false -PackConstructorInitializers: NextLine -PenaltyBreakAssignment: 2 -PenaltyBreakBeforeFirstCallParameter: 1 -PenaltyBreakComment: 300 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakOpenParenthesis: 0 -PenaltyBreakScopeResolution: 500 -PenaltyBreakString: 1000 -PenaltyBreakTemplateDeclaration: 10 -PenaltyExcessCharacter: 1000000 -PenaltyIndentedWhitespace: 0 -PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left -PPIndentWidth: -1 -QualifierAlignment: Leave -ReferenceAlignment: Pointer -ReflowComments: true -RemoveBracesLLVM: false -RemoveParentheses: Leave -RemoveSemicolon: false -RequiresClausePosition: OwnLine -RequiresExpressionIndentation: OuterScope -SeparateDefinitionBlocks: Leave -ShortNamespaceLines: 1 -SkipMacroDefinitionBody: false -SortIncludes: CaseSensitive -SortJavaStaticImport: Before -SortUsingDeclarations: LexicographicNumeric -SpaceAfterCStyleCast: false -SpaceAfterLogicalNot: false -SpaceAfterTemplateKeyword: true -SpaceAroundPointerQualifiers: Default -SpaceBeforeAssignmentOperators: true -SpaceBeforeCaseColon: false -SpaceBeforeCpp11BracedList: false -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeJsonColon: false -SpaceBeforeParens: ControlStatements -SpaceBeforeParensOptions: - AfterControlStatements: true - AfterForeachMacros: true - AfterFunctionDefinitionName: false - AfterFunctionDeclarationName: false - AfterIfMacros: true - AfterOverloadedOperator: false - AfterPlacementOperator: true - AfterRequiresInClause: false - AfterRequiresInExpression: false - BeforeNonEmptyParentheses: false -SpaceBeforeRangeBasedForLoopColon: true -SpaceBeforeSquareBrackets: false -SpaceInEmptyBlock: false -SpacesBeforeTrailingComments: 1 -SpacesInAngles: Never -SpacesInContainerLiterals: true -SpacesInLineCommentPrefix: - Minimum: 1 - Maximum: -1 -SpacesInParens: Never -SpacesInParensOptions: - InCStyleCasts: false - InConditionalStatements: false - InEmptyParentheses: false - Other: false -SpacesInSquareBrackets: false -Standard: Latest -StatementAttributeLikeMacros: - - Q_EMIT -StatementMacros: - - Q_UNUSED - - QT_REQUIRE_VERSION -TabWidth: 8 -UseTab: Never -VerilogBreakBetweenInstancePorts: true -WhitespaceSensitiveMacros: - - BOOST_PP_STRINGIZE - - CF_SWIFT_NAME - - NS_SWIFT_NAME - - PP_STRINGIZE - - STRINGIZE -... 
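Note: across patches 7-9 all three timestamp outputs flow through a single
TimestampOutputMode value. A recap sketch of the gating (the struct fields
come from main.cpp above; the body of enabled() is assumed, since the diffs
only show its call sites):

    struct TimestampOutputMode {
      bool subword = false;
      bool word = false;
      bool segment = false;
      // Assumed implementation; the patches only show enabled() being called.
      bool enabled() const {
        return subword || word || segment;
      }
    };
    // "--timestamps=all" sets all three flags. "--timestamps=segment" still
    // computes word spans internally, since segments are grouped from words.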
From 6744199d897699bf4c70b565a9bec3573995a9cb Mon Sep 17 00:00:00 2001
From: Matt Clayton
Date: Fri, 9 Jan 2026 14:06:41 -0500
Subject: [PATCH 9/9] More cleanup

---
 examples/models/parakeet/main.cpp | 274 +++++++++++++-----------------
 1 file changed, 118 insertions(+), 156 deletions(-)

diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index f9825fed975..be9f9bfd2b3 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -67,19 +67,16 @@ struct TimestampedTextSpan {
   double end_sec;
 };

-bool is_ascii_punctuation_only(std::string_view s) {
+bool is_ascii_punctuation_only(const std::string_view s) {
   if (s.empty()) {
     return false;
   }
-  for (unsigned char ch : s) {
-    if (!std::ispunct(ch)) {
-      return false;
-    }
-  }
-  return true;
+  return std::all_of(s.begin(), s.end(), [](const unsigned char ch) {
+    return std::ispunct(ch);
+  });
 }

-size_t ltrim_ascii_whitespace(std::string_view s) {
+size_t ltrim_ascii_whitespace(const std::string_view s) {
   size_t i = 0;
   while (i < s.size() && std::isspace(static_cast<unsigned char>(s[i]))) {
     i++;
@@ -87,7 +84,7 @@ size_t ltrim_ascii_whitespace(std::string_view s) {
   return i;
 }

-struct DecodedPiece {
+struct TokenizerDecodedPiece {
   std::string piece;
   size_t trimmed_offset = 0;
   bool had_leading_whitespace = false;
@@ -99,41 +96,28 @@ struct DecodedPiece {
   }
 };

-DecodedPiece decode_piece(
-    tokenizers::Tokenizer& tokenizer,
-    uint64_t prev_token_id,
-    uint64_t token_id) {
+TokenizerDecodedPiece decode_piece(
+    const tokenizers::Tokenizer& tokenizer,
+    const uint64_t prev_token_id,
+    const uint64_t token_id) {
   auto decode_result = tokenizer.decode(prev_token_id, token_id);
-  DecodedPiece decoded;
+  TokenizerDecodedPiece decoded;
   decoded.piece = decode_result.ok() ? decode_result.get() : std::string();
   decoded.trimmed_offset = ltrim_ascii_whitespace(decoded.piece);
   decoded.had_leading_whitespace = decoded.trimmed_offset > 0;
   return decoded;
 }

-TimestampedTextSpan make_timestamped_span(
-    std::string text,
-    int64_t start_offset,
-    int64_t end_offset,
-    double seconds_per_encoder_frame) {
-  return TimestampedTextSpan{
-      std::move(text),
-      start_offset,
-      end_offset,
-      seconds_per_encoder_frame * start_offset,
-      seconds_per_encoder_frame * end_offset};
-}
-
+// TDT sometimes emits punctuation long after preceding token. Thus, pin
+// timestamp to previous token. NeMo applies the same correction:
+// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
+// Divergence: NeMo consults `supported_punctuation` from the model; here we
+// approximate punctuation detection with `is_ascii_punctuation_only()`.
 Token apply_tdt_punctuation_timestamp_correction(
     const Token& token,
-    bool is_punct,
-    int64_t prev_end_offset,
-    bool has_prev_end_offset) {
-  // TDT sometimes emits punctuation long after preceding token. Thus, pin
-  // timestamp to previous token. NeMo applies the same correction:
-  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
-  // Divergence: NeMo consults `supported_punctuation` from the model; here we
-  // approximate punctuation detection with `is_ascii_punctuation_only()`.
+    const bool is_punct,
+    const int64_t prev_end_offset,
+    const bool has_prev_end_offset) {
   if (!is_punct || !has_prev_end_offset) {
     return token;
   }
   return Token{token.id, prev_end_offset, prev_end_offset};
 }
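Note: the hunk above only hoists the comment and const-qualifies the
parameters; the pinning behavior itself is unchanged. A hypothetical call
with made-up token id and offsets:

    // Suppose "dog" ended at encoder frame 42 and TDT emits "." at frame 58.
    Token period{/*id=*/17, /*start_offset=*/58, /*end_offset=*/58};
    Token pinned = apply_tdt_punctuation_timestamp_correction(
        period, /*is_punct=*/true, /*prev_end_offset=*/42,
        /*has_prev_end_offset=*/true);
    // pinned is {17, 42, 42}; non-punctuation tokens pass through unchanged.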
@@ -157,63 +141,36 @@ std::string to_lower_ascii(std::string s) {
   return s;
 }

-bool parse_timestamp_output_mode(
-    const std::string& raw,
-    TimestampOutputMode* out,
-    std::string* error) {
-  if (!out) {
-    if (error) {
-      *error = "Internal error: TimestampOutputMode output was null.";
-    }
-    return false;
+TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) {
+  if (raw_arg.empty()) {
+    throw std::invalid_argument(
+        "Invalid --timestamps value (empty). Expected: subword, word, segment, all.");
   }
-
-  const std::string mode = to_lower_ascii(raw);
+  const std::string mode = to_lower_ascii(raw_arg);
   if (mode == "none") {
-    *out = TimestampOutputMode{};
-    return true;
+    return {false, false, false};
   }
   if (mode == "subword") {
-    out->subword = true;
-    out->word = false;
-    out->segment = false;
-    return true;
+    return {true, false, false};
   }
   if (mode == "word") {
-    out->subword = false;
-    out->word = true;
-    out->segment = false;
-    return true;
+    return {false, true, false};
   }
   if (mode == "segment") {
-    out->subword = false;
-    out->word = false;
-    out->segment = true;
-    return true;
+    return {false, false, true};
   }
   if (mode == "all") {
-    out->subword = true;
-    out->word = true;
-    out->segment = true;
-    return true;
+    return {true, true, true};
   }
-
-  if (error) {
-    if (raw.empty()) {
-      *error =
-          "Invalid --timestamps value (empty). Expected: none, subword, word, segment, all.";
-    } else {
-      *error = "Invalid --timestamps value '" + raw +
-          "'. Expected: none, subword, word, segment, all.";
-    }
-  }
-  return false;
+  throw std::invalid_argument(
+      "Invalid --timestamps value '" + raw_arg +
+      "'. Expected: subword, word, segment, all.");
 }

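Note: values are case-insensitive via to_lower_ascii(), and bad input now
surfaces as an exception instead of through an out-parameter. Hypothetical
uses of the function above:

    TimestampOutputMode all = parse_timestamp_output_mode("ALL");
    // all.subword, all.word, and all.segment are all true.
    TimestampOutputMode off = parse_timestamp_output_mode("none");
    // off.enabled() is false; no timestamp output is printed.
    parse_timestamp_output_mode("words");  // throws std::invalid_argument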
 std::vector<TimestampedTextSpan> tokens_to_timestamped_subwords(
     const std::vector<Token>& tokens,
-    tokenizers::Tokenizer& tokenizer,
-    double seconds_per_encoder_frame) {
+    const tokenizers::Tokenizer& tokenizer,
+    const double seconds_per_encoder_frame) {
   // NeMo reference of TDT per-token "char" timestamp computation:
   // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991
   std::vector<TimestampedTextSpan> subwords;
@@ -224,17 +181,18 @@ std::vector<TimestampedTextSpan> tokens_to_timestamped_subwords(
   bool has_prev_end_offset = false;

   for (const auto& token : tokens) {
-    uint64_t token_id = static_cast<uint64_t>(token.id);
-    DecodedPiece decoded = decode_piece(tokenizer, bos_token, token_id);
+    TokenizerDecodedPiece decoded =
+        decode_piece(tokenizer, bos_token, token.id);
     const bool is_punct = is_ascii_punctuation_only(decoded.trimmed_piece());
     const Token adjusted = apply_tdt_punctuation_timestamp_correction(
         token, is_punct, prev_end_offset, has_prev_end_offset);

-    subwords.push_back(make_timestamped_span(
-        std::move(decoded.piece),
-        adjusted.start_offset,
-        adjusted.end_offset,
-        seconds_per_encoder_frame));
+    subwords.push_back(
+        {std::string(decoded.trimmed_piece()),
+         adjusted.start_offset,
+         adjusted.end_offset,
+         seconds_per_encoder_frame * adjusted.start_offset,
+         seconds_per_encoder_frame * adjusted.end_offset});

     prev_end_offset = adjusted.end_offset;
     has_prev_end_offset = true;
@@ -245,14 +203,13 @@ std::vector<TimestampedTextSpan> tokens_to_timestamped_subwords(

 std::vector<TimestampedTextSpan> tokens_to_timestamped_words(
     const std::vector<Token>& tokens,
-    tokenizers::Tokenizer& tokenizer,
-    double seconds_per_encoder_frame) {
+    const tokenizers::Tokenizer& tokenizer,
+    const double seconds_per_encoder_frame) {
   // NeMo reference for word grouping (subword/char offsets -> word offsets):
   // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
   std::vector<TimestampedTextSpan> words;

   uint64_t prev_token_id = 0;
-
   std::string current_word;
   int64_t current_start_offset = 0;
   int64_t current_end_offset = 0;
@@ -263,18 +220,19 @@ std::vector<TimestampedTextSpan> tokens_to_timestamped_words(
     if (current_word.empty()) {
       return;
     }
-    words.push_back(make_timestamped_span(
-        std::move(current_word),
-        current_start_offset,
-        current_end_offset,
-        seconds_per_encoder_frame));
+    words.push_back(
+        {std::move(current_word),
+         current_start_offset,
+         current_end_offset,
+         seconds_per_encoder_frame * current_start_offset,
+         seconds_per_encoder_frame * current_end_offset});
     current_word.clear();
   };

   for (const auto& token : tokens) {
-    uint64_t token_id = static_cast<uint64_t>(token.id);
-    DecodedPiece decoded = decode_piece(tokenizer, prev_token_id, token_id);
-    prev_token_id = token_id;
+    TokenizerDecodedPiece decoded =
+        decode_piece(tokenizer, prev_token_id, token.id);
+    prev_token_id = token.id;
     const std::string_view trimmed_piece = decoded.trimmed_piece();
     if (trimmed_piece.empty()) {
       continue;
@@ -328,11 +286,12 @@ std::vector<TimestampedTextSpan> timestamped_words_to_timestamped_segments(
     if (!has_segment || current_segment.empty()) {
       return;
     }
-    segments.push_back(make_timestamped_span(
-        std::move(current_segment),
-        segment_start_offset,
-        segment_end_offset,
-        seconds_per_encoder_frame));
+    segments.push_back(
+        {std::move(current_segment),
+         segment_start_offset,
+         segment_end_offset,
+         seconds_per_encoder_frame * segment_start_offset,
+         seconds_per_encoder_frame * segment_end_offset});
     current_segment.clear();
     has_segment = false;
   };
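Note: the word grouping preserved through this refactor keys off the
tokenizer's leading-whitespace marker. A hypothetical SentencePiece-style
trace (pieces and offsets invented for illustration):

    // decoded pieces:          " hel"   "lo"    ","     " world"
    // had_leading_whitespace:   true    false   false   true
    // resulting word spans:    "hello,"                 "world"
    // Each span keeps the start offset of its first piece and the end offset
    // of its last piece; the "," offsets were already pinned to the end of
    // "lo" by apply_tdt_punctuation_timestamp_correction() before merging.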
@@ -590,17 +549,16 @@ std::string tokens_to_text(
     const std::vector<Token>& tokens,
-    tokenizers::Tokenizer& tokenizer) {
+    const tokenizers::Tokenizer& tokenizer) {
   // Decode tokens to text one by one
   std::string result;
   uint64_t prev_token = 0;
   for (const auto& token : tokens) {
-    uint64_t token_id = static_cast<uint64_t>(token.id);
-    auto decode_result = tokenizer.decode(prev_token, token_id);
+    auto decode_result = tokenizer.decode(prev_token, token.id);
     if (decode_result.ok()) {
       result += decode_result.get();
     }
-    prev_token = token_id;
+    prev_token = token.id;
   }

   return result;
@@ -610,10 +568,15 @@ void print_timestamped_spans(
     const char* label,
     const std::vector<TimestampedTextSpan>& spans) {
   std::cout << "\n" << label << " timestamps:\n";
+  const std::ios_base::fmtflags old_flags = std::cout.flags();
+  const std::streamsize old_precision = std::cout.precision();
+  std::cout << std::fixed << std::setprecision(2);
   for (const auto& span : spans) {
     std::cout << "[" << span.start_sec << ", " << span.end_sec << "] ";
     std::cout << span.text << "\n";
   }
+  std::cout.flags(old_flags);
+  std::cout.precision(old_precision);
 }

 } // namespace
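Note: saving flags() and precision() keeps the std::fixed /
std::setprecision(2) formatting from leaking out of print_timestamped_spans()
into later std::cout output. An equivalent RAII guard, sketched here as an
alternative rather than what the patch does:

    #include <ios>

    // Restores a stream's format flags and precision on scope exit.
    class IosStateGuard {
     public:
      explicit IosStateGuard(std::ios& stream)
          : stream_(stream),
            flags_(stream.flags()),
            precision_(stream.precision()) {}
      ~IosStateGuard() {
        stream_.flags(flags_);
        stream_.precision(precision_);
      }

     private:
      std::ios& stream_;
      std::ios_base::fmtflags flags_;
      std::streamsize precision_;
    };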
@@ -622,10 +585,10 @@ int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);

   TimestampOutputMode timestamp_mode;
-  std::string timestamp_error;
-  if (!parse_timestamp_output_mode(
-          FLAGS_timestamps, &timestamp_mode, &timestamp_error)) {
-    ET_LOG(Error, "%s", timestamp_error.c_str());
+  try {
+    timestamp_mode = parse_timestamp_output_mode(FLAGS_timestamps);
+  } catch (const std::invalid_argument& e) {
+    ET_LOG(Error, "%s", e.what());
     return 1;
   }

@@ -713,7 +676,7 @@ int main(int argc, char** argv) {
       static_cast<long long>(encoded_len));

   // Query model metadata from constant_methods
-  std::vector<::executorch::runtime::EValue> empty_inputs;
+  const std::vector<::executorch::runtime::EValue> empty_inputs;
   auto num_rnn_layers_result = model->execute("num_rnn_layers", empty_inputs);
   auto pred_hidden_result = model->execute("pred_hidden", empty_inputs);
   auto vocab_size_result = model->execute("vocab_size", empty_inputs);
@@ -772,56 +735,55 @@ int main(int argc, char** argv) {
   std::string text = tokens_to_text(tokens, *tokenizer);
   std::cout << "Transcribed text: " << text << std::endl;

-  if (timestamp_mode.enabled()) {
-    std::vector<::executorch::runtime::EValue> empty_inputs;
-    auto window_stride_result = model->execute("window_stride", empty_inputs);
-    auto subsampling_factor_result =
-        model->execute("encoder_subsampling_factor", empty_inputs);
-
-    double seconds_per_encoder_frame = -1.0;
-    if (window_stride_result.ok() && subsampling_factor_result.ok()) {
-      double window_stride = window_stride_result.get()[0].toDouble();
-      int64_t encoder_subsampling_factor =
-          subsampling_factor_result.get()[0].toInt();
-      seconds_per_encoder_frame = window_stride * encoder_subsampling_factor;
-      ET_LOG(
-          Info,
-          "Timestamp metadata: window_stride=%f, encoder_subsampling_factor=%lld, seconds_per_encoder_frame=%f",
-          window_stride,
-          static_cast<long long>(encoder_subsampling_factor),
-          seconds_per_encoder_frame);
-    } else {
-      ET_LOG(
-          Error,
-          "Timestamps requested (--timestamps=%s) but model metadata is missing. Re-export the model with constant_methods for window_stride and encoder_subsampling_factor.",
-          FLAGS_timestamps.c_str());
-      return 1;
-    }
+  if (!timestamp_mode.enabled()) {
+    return 0;
+  }

-    std::cout << std::fixed << std::setprecision(2);
-    if (timestamp_mode.subword) {
-      print_timestamped_spans(
-          "Subword",
-          tokens_to_timestamped_subwords(
-              tokens, *tokenizer, seconds_per_encoder_frame));
-    }
+  // Query timestamp metadata
+  auto window_stride_result = model->execute("window_stride", empty_inputs);
+  auto subsampling_factor_result =
+      model->execute("encoder_subsampling_factor", empty_inputs);
+  if (!window_stride_result.ok() || !subsampling_factor_result.ok()) {
+    ET_LOG(
+        Error,
+        "Timestamps requested (--timestamps=%s) but model metadata is missing. Re-export the model with constant_methods for window_stride and encoder_subsampling_factor.",
+        FLAGS_timestamps.c_str());
+    return 1;
+  }

-    std::vector<TimestampedTextSpan> words;
-    if (timestamp_mode.word || timestamp_mode.segment) {
-      words = tokens_to_timestamped_words(
-          tokens, *tokenizer, seconds_per_encoder_frame);
-    }
-    if (timestamp_mode.word) {
-      print_timestamped_spans("Word", words);
-    }
-    if (timestamp_mode.segment) {
-      print_timestamped_spans(
-          "Segment",
-          timestamped_words_to_timestamped_segments(
-              words, seconds_per_encoder_frame));
-    }
+  double window_stride = window_stride_result.get()[0].toDouble();
+  int64_t encoder_subsampling_factor =
+      subsampling_factor_result.get()[0].toInt();
+  const double seconds_per_encoder_frame =
+      window_stride * encoder_subsampling_factor;
+  ET_LOG(
+      Info,
+      "Timestamp metadata: window_stride=%f, encoder_subsampling_factor=%lld, seconds_per_encoder_frame=%f",
+      window_stride,
+      static_cast<long long>(encoder_subsampling_factor),
+      seconds_per_encoder_frame);
+
+  if (timestamp_mode.subword) {
+    print_timestamped_spans(
+        "Subword",
+        tokens_to_timestamped_subwords(
+            tokens, *tokenizer, seconds_per_encoder_frame));
+  }
+
+  std::vector<TimestampedTextSpan> words;
+  if (timestamp_mode.word || timestamp_mode.segment) {
+    words = tokens_to_timestamped_words(
+        tokens, *tokenizer, seconds_per_encoder_frame);
+  }
+  if (timestamp_mode.word) {
+    print_timestamped_spans("Word", words);
+  }
+  if (timestamp_mode.segment) {
+    print_timestamped_spans(
+        "Segment",
+        timestamped_words_to_timestamped_segments(
+            words, seconds_per_encoder_frame));
+  }
-
   ET_LOG(Info, "Done!");
   return 0;
 }
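Note: end to end, an invocation like the hypothetical one below (binary name
and file paths invented; flag names from the README) prints output shaped as
follows, with every number illustrative rather than from a real run:

    ./parakeet_main --audio_path=sample.wav --tokenizer_path=tokenizer.json \
        --timestamps=all

    Transcribed text:  hello world.

    Subword timestamps:
    [0.08, 0.16] hel
    [0.16, 0.16] lo
    [0.24, 0.32] world
    [0.32, 0.32] .

    Word timestamps:
    [0.08, 0.16] hello
    [0.24, 0.32] world.

    Segment timestamps:
    [0.08, 0.32] hello world.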