diff --git a/.clang-format b/examples/models/parakeet/.clang-format
similarity index 100%
rename from .clang-format
rename to examples/models/parakeet/.clang-format
diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md
index b27bc1f8a91..ab4eb8640b3 100644
--- a/examples/models/parakeet/README.md
+++ b/examples/models/parakeet/README.md
@@ -71,3 +71,4 @@ From the executorch root directory:
 | `--audio_path` | Path to input audio file (.wav) |
 | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) |
 | `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) |
+| `--timestamps` | Timestamp output mode: none\|subword\|word\|segment\|all (requires `window_stride` and `encoder_subsampling_factor` in model metadata) |
diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py
index 92e32ca30bf..3990e1165d2 100644
--- a/examples/models/parakeet/export_parakeet_tdt.py
+++ b/examples/models/parakeet/export_parakeet_tdt.py
@@ -351,6 +351,8 @@ def export_all(model):
     )
 
     sample_rate = model.preprocessor._cfg.sample_rate
+    window_stride = float(model.preprocessor._cfg.window_stride)
+    encoder_subsampling_factor = int(getattr(model.encoder, "subsampling_factor", 8))
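+    # The runner converts encoder frame offsets to seconds as
+    # seconds_per_encoder_frame = window_stride * encoder_subsampling_factor,
+    # e.g. (illustrative, typical Parakeet/Conformer values) 0.01 s * 8 = 0.08 s.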
     metadata = {
         "num_rnn_layers": num_layers,
         "pred_hidden": pred_hidden,
@@ -358,6 +360,8 @@ def export_all(model):
         "vocab_size": model.tokenizer.vocab_size,
         "blank_id": model.tokenizer.vocab_size,
         "sample_rate": sample_rate,
+        "window_stride": window_stride,
+        "encoder_subsampling_factor": encoder_subsampling_factor,
     }
 
     return programs, metadata
diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp
index 026f3911a3d..be9f9bfd2b3 100644
--- a/examples/models/parakeet/main.cpp
+++ b/examples/models/parakeet/main.cpp
@@ -6,12 +6,16 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <cmath>
+#include <algorithm>
+#include <cctype>
+#include <iomanip>
 #include <iostream>
-#include <memory>
 #include <string>
 #include <vector>
 #include <gflags/gflags.h>
+#include <stdexcept>
+#include <string_view>
 #include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
@@ -33,6 +37,10 @@ DEFINE_string(
     data_path,
     "",
     "Path to data file (.ptd) for delegate data (optional, required for CUDA).");
+DEFINE_string(
+    timestamps,
+    "none",
+    "Timestamp output mode: none|subword|word|segment|all");
 
 using ::executorch::extension::from_blob;
 using ::executorch::extension::Module;
@@ -41,10 +49,281 @@ using ::executorch::runtime::EValue;
 
 namespace {
 
-// TDT duration values
+// TDT duration values (comes from the model config in the NeMo implementation)
+// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/models/rnnt_models.py#L230-L238
 const std::vector<int64_t> DURATIONS = {0, 1, 2, 3, 4};
 
-std::vector<int64_t> greedy_decode_executorch(
+struct Token {
+  int64_t id;
+  int64_t start_offset; // encoder frame index
+  int64_t end_offset; // encoder frame index
+};
+
+struct TimestampedTextSpan {
+  std::string text;
+  int64_t start_offset;
+  int64_t end_offset;
+  double start_sec;
+  double end_sec;
+};
+
+bool is_ascii_punctuation_only(const std::string_view s) {
+  if (s.empty()) {
+    return false;
+  }
+  return std::all_of(s.begin(), s.end(), [](const unsigned char ch) {
+    return std::ispunct(ch);
+  });
+}
+
+size_t ltrim_ascii_whitespace(const std::string_view s) {
+  size_t i = 0;
+  while (i < s.size() && std::isspace(static_cast<unsigned char>(s[i]))) {
+    i++;
+  }
+  return i;
+}
+
+struct TokenizerDecodedPiece {
+  std::string piece;
+  size_t trimmed_offset = 0;
+  bool had_leading_whitespace = false;
+
+  std::string_view trimmed_piece() const {
+    std::string_view trimmed(piece);
+    trimmed.remove_prefix(trimmed_offset);
+    return trimmed;
+  }
+};
+
+TokenizerDecodedPiece decode_piece(
+    const tokenizers::Tokenizer& tokenizer,
+    const uint64_t prev_token_id,
+    const uint64_t token_id) {
+  auto decode_result = tokenizer.decode(prev_token_id, token_id);
+  TokenizerDecodedPiece decoded;
+  decoded.piece = decode_result.ok() ? decode_result.get() : std::string();
+  decoded.trimmed_offset = ltrim_ascii_whitespace(decoded.piece);
+  decoded.had_leading_whitespace = decoded.trimmed_offset > 0;
+  return decoded;
+}
+
+// TDT sometimes emits a punctuation token long after the preceding token, so
+// pin its timestamp to the end of the previous token. NeMo applies the same
+// correction:
+// https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L1169-L1189
+// Divergence: NeMo consults `supported_punctuation` from the model; here we
+// approximate punctuation detection with `is_ascii_punctuation_only()`.
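+// Example (illustrative): if the previous token ended at encoder frame 12 and
+// a lone "," is emitted at frame 40, the comma's span is pinned to [12, 12].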
+Token apply_tdt_punctuation_timestamp_correction(
+    const Token& token,
+    const bool is_punct,
+    const int64_t prev_end_offset,
+    const bool has_prev_end_offset) {
+  if (!is_punct || !has_prev_end_offset) {
+    return token;
+  }
+  return Token{token.id, prev_end_offset, prev_end_offset};
+}
+
+struct TimestampOutputMode {
+  bool subword = false;
+  bool word = false;
+  bool segment = false;
+
+  bool enabled() const {
+    return subword || word || segment;
+  }
+};
+
+std::string to_lower_ascii(std::string s) {
+  for (char& ch : s) {
+    ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
+  }
+  return s;
+}
+
+TimestampOutputMode parse_timestamp_output_mode(const std::string& raw_arg) {
+  if (raw_arg.empty()) {
+    throw std::invalid_argument(
+        "Invalid --timestamps value (empty). "
+        "Expected: none, subword, word, segment, all.");
+  }
+  const std::string mode = to_lower_ascii(raw_arg);
+  if (mode == "none") {
+    return {false, false, false};
+  }
+  if (mode == "subword") {
+    return {true, false, false};
+  }
+  if (mode == "word") {
+    return {false, true, false};
+  }
+  if (mode == "segment") {
+    return {false, false, true};
+  }
+  if (mode == "all") {
+    return {true, true, true};
+  }
+  throw std::invalid_argument(
+      "Invalid --timestamps value '" + raw_arg +
+      "'. Expected: none, subword, word, segment, all.");
+}
+
+std::vector<TimestampedTextSpan> tokens_to_timestamped_subwords(
+    const std::vector<Token>& tokens,
+    const tokenizers::Tokenizer& tokenizer,
+    const double seconds_per_encoder_frame) {
+  // NeMo reference for the TDT per-token "char" timestamp computation:
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/rnnt_decoding.py#L991
+  std::vector<TimestampedTextSpan> subwords;
+  subwords.reserve(tokens.size());
+
+  const uint64_t bos_token = tokenizer.bos_tok();
+  int64_t prev_end_offset = 0;
+  bool has_prev_end_offset = false;
+
+  for (const auto& token : tokens) {
+    TokenizerDecodedPiece decoded =
+        decode_piece(tokenizer, bos_token, token.id);
+    const bool is_punct = is_ascii_punctuation_only(decoded.trimmed_piece());
+    const Token adjusted = apply_tdt_punctuation_timestamp_correction(
+        token, is_punct, prev_end_offset, has_prev_end_offset);
+
+    subwords.push_back(
+        {std::string(decoded.trimmed_piece()),
+         adjusted.start_offset,
+         adjusted.end_offset,
+         seconds_per_encoder_frame * adjusted.start_offset,
+         seconds_per_encoder_frame * adjusted.end_offset});
+
+    prev_end_offset = adjusted.end_offset;
+    has_prev_end_offset = true;
+  }
+
+  return subwords;
+}
+
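+// Worked example of the offset -> seconds conversion (assumed metadata; the
+// real values come from the exported model): window_stride = 0.01 s and
+// encoder_subsampling_factor = 8 give seconds_per_encoder_frame = 0.08 s, so
+// a token spanning encoder frames [10, 12] prints as [0.80, 0.96].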
+std::vector<TimestampedTextSpan> tokens_to_timestamped_words(
+    const std::vector<Token>& tokens,
+    const tokenizers::Tokenizer& tokenizer,
+    const double seconds_per_encoder_frame) {
+  // NeMo reference for word grouping (subword/char offsets -> word offsets):
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L54-L224
+  std::vector<TimestampedTextSpan> words;
+
+  uint64_t prev_token_id = 0;
+  std::string current_word;
+  int64_t current_start_offset = 0;
+  int64_t current_end_offset = 0;
+  int64_t prev_end_offset = 0;
+  bool has_prev_end_offset = false;
+
+  auto emit_word = [&]() {
+    if (current_word.empty()) {
+      return;
+    }
+    words.push_back(
+        {std::move(current_word),
+         current_start_offset,
+         current_end_offset,
+         seconds_per_encoder_frame * current_start_offset,
+         seconds_per_encoder_frame * current_end_offset});
+    current_word.clear();
+  };
+
+  for (const auto& token : tokens) {
+    TokenizerDecodedPiece decoded =
+        decode_piece(tokenizer, prev_token_id, token.id);
+    prev_token_id = token.id;
+    const std::string_view trimmed_piece = decoded.trimmed_piece();
+    if (trimmed_piece.empty()) {
+      continue;
+    }
+
+    const bool is_punct = is_ascii_punctuation_only(trimmed_piece);
+    const Token adjusted = apply_tdt_punctuation_timestamp_correction(
+        token, is_punct, prev_end_offset, has_prev_end_offset);
+
+    if (current_word.empty()) {
+      current_word.assign(trimmed_piece.data(), trimmed_piece.size());
+      current_start_offset = adjusted.start_offset;
+      current_end_offset = adjusted.end_offset;
+    } else if (decoded.had_leading_whitespace && !is_punct) {
+      // NeMo builds words from decoded token offsets with tokenizer-aware
+      // rules:
+      // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L79-L99
+      // Here we simply build words token by token, starting a new word
+      // whenever a non-punctuation piece decodes with leading whitespace.
+      emit_word();
+      current_word.assign(trimmed_piece.data(), trimmed_piece.size());
+      current_start_offset = adjusted.start_offset;
+      current_end_offset = adjusted.end_offset;
+    } else {
+      current_word.append(trimmed_piece.data(), trimmed_piece.size());
+      current_end_offset = adjusted.end_offset;
+    }
+
+    prev_end_offset = adjusted.end_offset;
+    has_prev_end_offset = true;
+  }
+
+  emit_word();
+  return words;
+}
+
+std::vector<TimestampedTextSpan> timestamped_words_to_timestamped_segments(
+    const std::vector<TimestampedTextSpan>& words,
+    double seconds_per_encoder_frame) {
+  // NeMo reference for segment grouping (word offsets -> segment offsets):
+  // https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/utils/timestamp_utils.py#L227-L327
+  std::vector<TimestampedTextSpan> segments;
+  if (words.empty()) {
+    return segments;
+  }
+
+  std::string current_segment;
+  int64_t segment_start_offset = 0;
+  int64_t segment_end_offset = 0;
+  bool has_segment = false;
+
+  auto emit_segment = [&]() {
+    if (!has_segment || current_segment.empty()) {
+      return;
+    }
+    segments.push_back(
+        {std::move(current_segment),
+         segment_start_offset,
+         segment_end_offset,
+         seconds_per_encoder_frame * segment_start_offset,
+         seconds_per_encoder_frame * segment_end_offset});
+    current_segment.clear();
+    has_segment = false;
+  };
+
+  for (const auto& word : words) {
+    if (!has_segment) {
+      has_segment = true;
+      current_segment = word.text;
+      segment_start_offset = word.start_offset;
+      segment_end_offset = word.end_offset;
+    } else {
+      current_segment += " ";
+      current_segment += word.text;
+      segment_end_offset = word.end_offset;
+    }
+
+    if (!word.text.empty()) {
+      char last = word.text.back();
+      // Divergence from NeMo: we split segments only on terminal punctuation
+      // ('.', '!', '?') rather than on the configurable
+      // `segment_delimiter_tokens`, and we do no `segment_gap_threshold`
+      // splitting.
+      if (last == '.' || last == '!' || last == '?') {
+        emit_segment();
+      }
+    }
+  }
+
+  emit_segment();
+  return segments;
+}
+
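+// Greedy TDT label-looping decoding. In addition to the label ids, it now
+// records the encoder frame range each label was decoded from: a label
+// accepted at frame t with predicted duration dur is stored as
+// Token{k, t, t + dur}.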
+std::vector<Token> greedy_decode_executorch(
     Module& model,
     const ::executorch::aten::Tensor& encoder_output,
     int64_t encoder_len,
@@ -53,7 +332,7 @@ std::vector<int64_t> greedy_decode_executorch(
     int64_t num_rnn_layers = 2,
     int64_t pred_hidden = 640,
     int64_t max_symbols_per_step = 10) {
-  std::vector<int64_t> hypothesis;
+  std::vector<Token> hypothesis;
 
   int64_t num_token_classes = vocab_size + 1;
   // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim]
@@ -108,9 +387,9 @@
   // Prime the prediction network state with SOS (= blank_id) to match NeMo TDT
   // greedy label-looping decoding behavior:
   // - SOS is defined as blank:
-  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c980b70cecc184fa8a083a9c3ddb87f905e/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py#L250
+  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py#L250
   // - Predictor priming with SOS:
-  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c980b70cecc184fa8a083a9c3ddb87f905e/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py#L363-L368
+  //   https://github.com/NVIDIA-NeMo/NeMo/blob/bf583c9/nemo/collections/asr/parts/submodules/transducer_decoding/tdt_label_looping.py#L363-L368
   std::vector<int64_t> sos_token_data = {blank_id};
   auto sos_token = from_blob(
       sos_token_data.data(), {1, 1}, ::executorch::aten::ScalarType::Long);
@@ -208,7 +487,7 @@
       t += std::max(dur, (int64_t)1);
       symbols_on_frame = 0;
     } else {
-      hypothesis.push_back(k);
+      hypothesis.push_back(Token{k, t, t + dur});
 
       // Update decoder state
       std::vector<int64_t> token_data = {k};
@@ -269,28 +548,50 @@
 }
 
 std::string tokens_to_text(
-    const std::vector<int64_t>& tokens,
-    tokenizers::Tokenizer* tokenizer) {
+    const std::vector<Token>& tokens,
+    const tokenizers::Tokenizer& tokenizer) {
   // Decode tokens to text one by one
   std::string result;
   uint64_t prev_token = 0;
-  for (size_t i = 0; i < tokens.size(); i++) {
-    uint64_t token = static_cast<uint64_t>(tokens[i]);
-    auto decode_result = tokenizer->decode(prev_token, token);
+  for (const auto& token : tokens) {
+    auto decode_result = tokenizer.decode(prev_token, token.id);
     if (decode_result.ok()) {
       result += decode_result.get();
    }
-    prev_token = token;
+    prev_token = token.id;
  }
 
   return result;
 }
 
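+// Prints one "[start_sec, end_sec] text" line per span, with seconds fixed to
+// two decimals, e.g. (illustrative): [0.08, 0.24] hello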
+void print_timestamped_spans(
+    const char* label,
+    const std::vector<TimestampedTextSpan>& spans) {
+  std::cout << "\n" << label << " timestamps:\n";
+  const std::ios_base::fmtflags old_flags = std::cout.flags();
+  const std::streamsize old_precision = std::cout.precision();
+  std::cout << std::fixed << std::setprecision(2);
+  for (const auto& span : spans) {
+    std::cout << "[" << span.start_sec << ", " << span.end_sec << "] ";
+    std::cout << span.text << "\n";
+  }
+  std::cout.flags(old_flags);
+  std::cout.precision(old_precision);
+}
+
 } // namespace
 
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
+  TimestampOutputMode timestamp_mode;
+  try {
+    timestamp_mode = parse_timestamp_output_mode(FLAGS_timestamps);
+  } catch (const std::invalid_argument& e) {
+    ET_LOG(Error, "%s", e.what());
+    return 1;
+  }
+
   if (FLAGS_audio_path.empty()) {
     ET_LOG(Error, "audio_path flag must be provided.");
     return 1;
@@ -375,7 +676,7 @@
       static_cast<long long>(encoded_len));
 
   // Query model metadata from constant_methods
-  std::vector<::executorch::runtime::EValue> empty_inputs;
+  const std::vector<::executorch::runtime::EValue> empty_inputs;
   auto num_rnn_layers_result = model->execute("num_rnn_layers", empty_inputs);
   auto pred_hidden_result = model->execute("pred_hidden", empty_inputs);
   auto vocab_size_result = model->execute("vocab_size", empty_inputs);
@@ -431,9 +732,58 @@
   }
 
   // Convert tokens to text
-  std::string text = tokens_to_text(tokens, tokenizer.get());
-  std::cout << "Transcription tokens: " << text << std::endl;
+  std::string text = tokens_to_text(tokens, *tokenizer);
+  std::cout << "Transcribed text: " << text << std::endl;
+
+  if (!timestamp_mode.enabled()) {
+    return 0;
+  }
+
+  // Query timestamp metadata
+  auto window_stride_result = model->execute("window_stride", empty_inputs);
+  auto subsampling_factor_result =
+      model->execute("encoder_subsampling_factor", empty_inputs);
+  if (!window_stride_result.ok() || !subsampling_factor_result.ok()) {
+    ET_LOG(
+        Error,
+        "Timestamps requested (--timestamps=%s) but model metadata is missing. Re-export the model with constant_methods for window_stride and encoder_subsampling_factor.",
+        FLAGS_timestamps.c_str());
+    return 1;
+  }
+
+  double window_stride = window_stride_result.get()[0].toDouble();
+  int64_t encoder_subsampling_factor =
+      subsampling_factor_result.get()[0].toInt();
+  const double seconds_per_encoder_frame =
+      window_stride * encoder_subsampling_factor;
+  ET_LOG(
+      Info,
+      "Timestamp metadata: window_stride=%f, encoder_subsampling_factor=%lld, seconds_per_encoder_frame=%f",
+      window_stride,
+      static_cast<long long>(encoder_subsampling_factor),
+      seconds_per_encoder_frame);
+
+  if (timestamp_mode.subword) {
+    print_timestamped_spans(
+        "Subword",
+        tokens_to_timestamped_subwords(
+            tokens, *tokenizer, seconds_per_encoder_frame));
+  }
+
+  std::vector<TimestampedTextSpan> words;
+  if (timestamp_mode.word || timestamp_mode.segment) {
+    words = tokens_to_timestamped_words(
+        tokens, *tokenizer, seconds_per_encoder_frame);
+  }
+  if (timestamp_mode.word) {
+    print_timestamped_spans("Word", words);
+  }
+  if (timestamp_mode.segment) {
+    print_timestamped_spans(
+        "Segment",
+        timestamped_words_to_timestamped_segments(
+            words, seconds_per_encoder_frame));
+  }
 
-  ET_LOG(Info, "Done!");
   return 0;
 }