diff --git a/src/bin/Makefile b/src/bin/Makefile index 171b73aa63b..058b51ecee5 100644 --- a/src/bin/Makefile +++ b/src/bin/Makefile @@ -28,7 +28,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \ OBJFILES = -ADDLIBS = ../cudamatrix/kaldi-cudamatrix.a ../nnet3/kaldi-nnet3.a ../rnnlm/kaldi-rnnlm.a ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a ../lm/kaldi-lm.a \ +ADDLIBS = ../decoder/kaldi-decoder.a ../lat/kaldi-lat.a ../lm/kaldi-lm.a \ ../fstext/kaldi-fstext.a ../hmm/kaldi-hmm.a \ ../transform/kaldi-transform.a ../gmm/kaldi-gmm.a \ ../tree/kaldi-tree.a ../util/kaldi-util.a \ diff --git a/src/bin/lat2gen-biglm-faster-mapped.cc b/src/bin/lat2gen-biglm-faster-mapped.cc index c186e1a476b..b31fd4b07d5 100644 --- a/src/bin/lat2gen-biglm-faster-mapped.cc +++ b/src/bin/lat2gen-biglm-faster-mapped.cc @@ -25,8 +25,6 @@ #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" #include "decoder/decodable-matrix.h" -#include "lm/const-arpa-lm.h" -#include "rnnlm/rnnlm-lattice-rescoring.h" #include "base/timer.h" #include "decoder/lattice2-biglm-faster-decoder.h" @@ -162,26 +160,14 @@ int main(int argc, char *argv[]) { bool allow_partial = false; BaseFloat acoustic_scale = 0.1; Lattice2BiglmFasterDecoderConfig config; - int32 max_ngram_order = 4; - rnnlm::RnnlmComputeStateComputationOptions rnn_opts; - bool use_carpa = false; - std::string word_syms_filename, word_embedding_rxfilename; + std::string word_syms_filename; config.Register(&po); - rnn_opts.Register(&po); po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]"); po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached."); - po.Register("use-const-arpa", &use_carpa, "If true, read the old-LM file " - "as a const-arpa file as opposed to an FST file"); - po.Register("word-embedding-rxfilename", &word_embedding_rxfilename, "If set, use rnnlm"); - po.Register("max-ngram-order", &max_ngram_order, - "If positive, allow RNNLM histories longer than this to be identified " - "with each other for rescoring purposes (an approximation that " - "saves time and reduces output lattice size)."); - - + po.Read(argc, argv); if (po.NumArgs() < 6 || po.NumArgs() > 8) { @@ -201,39 +187,17 @@ int main(int argc, char *argv[]) { TransitionModel trans_model; ReadKaldiObject(model_in_filename, &trans_model); - VectorFst *old_lm_fst = fst::ReadAndPrepareLmFst( - old_lm_fst_rxfilename); - fst::BackoffDeterministicOnDemandFst old_lm_dfst(*old_lm_fst); - fst::ScaleDeterministicOnDemandFst old_lm_sdfst(-1, - &old_lm_dfst); - - fst::DeterministicOnDemandFst* new_lm_dfst = NULL; - VectorFst *new_lm_fst = NULL; - ConstArpaLm* const_arpa = NULL; - CuMatrix* word_embedding_mat = NULL; - kaldi::nnet3::Nnet *rnnlm = NULL; - const rnnlm::RnnlmComputeStateInfo *info = NULL; - - if (word_embedding_rxfilename!="") { - rnnlm = new kaldi::nnet3::Nnet(); - word_embedding_mat = new CuMatrix(); - ReadKaldiObject(word_embedding_rxfilename, word_embedding_mat); - ReadKaldiObject(new_lm_fst_rxfilename, rnnlm); - info = new rnnlm::RnnlmComputeStateInfo(rnn_opts, *rnnlm, *word_embedding_mat); - new_lm_dfst = new rnnlm::KaldiRnnlmDeterministicFst(max_ngram_order, *info); - } else if (use_carpa) { - const_arpa = new ConstArpaLm(); - ReadKaldiObject(new_lm_fst_rxfilename, const_arpa); - new_lm_dfst = new ConstArpaLmDeterministicFst(*const_arpa); - } else { - new_lm_fst = 
fst::ReadAndPrepareLmFst( - new_lm_fst_rxfilename); - new_lm_dfst = - new fst::BackoffDeterministicOnDemandFst(*new_lm_fst); - } + VectorFst *old_lm_fst = fst::CastOrConvertToVectorFst( + fst::ReadFstKaldiGeneric(old_lm_fst_rxfilename)); + ApplyProbabilityScale(-1.0, old_lm_fst); // Negate old LM probs... + + VectorFst *new_lm_fst = fst::CastOrConvertToVectorFst( + fst::ReadFstKaldiGeneric(new_lm_fst_rxfilename)); - fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_sdfst, - new_lm_dfst); + fst::BackoffDeterministicOnDemandFst old_lm_dfst(*old_lm_fst); + fst::BackoffDeterministicOnDemandFst new_lm_dfst(*new_lm_fst); + fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_dfst, + &new_lm_dfst); fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst, 1e7); bool determinize = config.determinize_lattice; @@ -257,7 +221,6 @@ int main(int argc, char *argv[]) { double tot_like = 0.0; kaldi::int64 frame_count = 0; int num_success = 0, num_fail = 0; - double elapsed = 0; if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { @@ -292,13 +255,13 @@ int main(int argc, char *argv[]) { num_success++; } else num_fail++; } - elapsed = timer.Elapsed(); } delete decode_fst; // delete this only after decoder goes out of scope. } else { // We have different FSTs for different utterances. assert(0); } + double elapsed = timer.Elapsed(); KALDI_LOG << "Time taken "<< elapsed << "s: real-time factor assuming 100 frames/sec is " << (elapsed*100.0/frame_count); @@ -308,14 +271,6 @@ int main(int argc, char *argv[]) { << frame_count<<" frames."; delete word_syms; - - delete const_arpa; - delete new_lm_fst; - delete new_lm_dfst; - delete word_embedding_mat; - delete rnnlm; - delete info; - if (num_success != 0) return 0; else return 1; } catch(const std::exception &e) { diff --git a/src/bin/latgen-biglm-faster-mapped.cc b/src/bin/latgen-biglm-faster-mapped.cc index deff3b9093e..1f87572a4f3 100644 --- a/src/bin/latgen-biglm-faster-mapped.cc +++ b/src/bin/latgen-biglm-faster-mapped.cc @@ -25,8 +25,6 @@ #include "fstext/fstext-lib.h" #include "decoder/decoder-wrappers.h" #include "decoder/decodable-matrix.h" -#include "lm/const-arpa-lm.h" -#include "rnnlm/rnnlm-lattice-rescoring.h" #include "base/timer.h" #include "decoder/lattice-biglm-faster-decoder.h" @@ -162,26 +160,14 @@ int main(int argc, char *argv[]) { bool allow_partial = false; BaseFloat acoustic_scale = 0.1; LatticeBiglmFasterDecoderConfig config; - int32 max_ngram_order = 4; - rnnlm::RnnlmComputeStateComputationOptions rnn_opts; - bool use_carpa = false; - std::string word_syms_filename, word_embedding_rxfilename; + std::string word_syms_filename; config.Register(&po); - rnn_opts.Register(&po); po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods"); po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]"); po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached."); - po.Register("use-const-arpa", &use_carpa, "If true, read the old-LM file " - "as a const-arpa file as opposed to an FST file"); - po.Register("word-embedding-rxfilename", &word_embedding_rxfilename, "If set, use rnnlm"); - po.Register("max-ngram-order", &max_ngram_order, - "If positive, allow RNNLM histories longer than this to be identified " - "with each other for rescoring purposes (an approximation that " - "saves time and reduces output lattice size)."); - - + po.Read(argc, argv); if (po.NumArgs() < 6 || po.NumArgs() > 8) { @@ -201,39 
+187,17 @@ int main(int argc, char *argv[]) { TransitionModel trans_model; ReadKaldiObject(model_in_filename, &trans_model); - VectorFst *old_lm_fst = fst::ReadAndPrepareLmFst( - old_lm_fst_rxfilename); - fst::BackoffDeterministicOnDemandFst old_lm_dfst(*old_lm_fst); - fst::ScaleDeterministicOnDemandFst old_lm_sdfst(-1, - &old_lm_dfst); - - fst::DeterministicOnDemandFst* new_lm_dfst = NULL; - VectorFst *new_lm_fst = NULL; - ConstArpaLm* const_arpa = NULL; - CuMatrix* word_embedding_mat = NULL; - kaldi::nnet3::Nnet *rnnlm = NULL; - const rnnlm::RnnlmComputeStateInfo *info = NULL; - - if (word_embedding_rxfilename!="") { - rnnlm = new kaldi::nnet3::Nnet(); - word_embedding_mat = new CuMatrix(); - ReadKaldiObject(word_embedding_rxfilename, word_embedding_mat); - ReadKaldiObject(new_lm_fst_rxfilename, rnnlm); - info = new rnnlm::RnnlmComputeStateInfo(rnn_opts, *rnnlm, *word_embedding_mat); - new_lm_dfst = new rnnlm::KaldiRnnlmDeterministicFst(max_ngram_order, *info); - } else if (use_carpa) { - const_arpa = new ConstArpaLm(); - ReadKaldiObject(new_lm_fst_rxfilename, const_arpa); - new_lm_dfst = new ConstArpaLmDeterministicFst(*const_arpa); - } else { - new_lm_fst = fst::ReadAndPrepareLmFst( - new_lm_fst_rxfilename); - new_lm_dfst = - new fst::BackoffDeterministicOnDemandFst(*new_lm_fst); - } + VectorFst *old_lm_fst = fst::CastOrConvertToVectorFst( + fst::ReadFstKaldiGeneric(old_lm_fst_rxfilename)); + ApplyProbabilityScale(-1.0, old_lm_fst); // Negate old LM probs... + + VectorFst *new_lm_fst = fst::CastOrConvertToVectorFst( + fst::ReadFstKaldiGeneric(new_lm_fst_rxfilename)); - fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_sdfst, - new_lm_dfst); + fst::BackoffDeterministicOnDemandFst old_lm_dfst(*old_lm_fst); + fst::BackoffDeterministicOnDemandFst new_lm_dfst(*new_lm_fst); + fst::ComposeDeterministicOnDemandFst compose_dfst(&old_lm_dfst, + &new_lm_dfst); fst::CacheDeterministicOnDemandFst cache_dfst(&compose_dfst, 1e7); bool determinize = config.determinize_lattice; @@ -257,7 +221,6 @@ int main(int argc, char *argv[]) { double tot_like = 0.0; kaldi::int64 frame_count = 0; int num_success = 0, num_fail = 0; - double elapsed = 0; if (ClassifyRspecifier(fst_in_str, NULL, NULL) == kNoRspecifier) { @@ -292,13 +255,13 @@ int main(int argc, char *argv[]) { num_success++; } else num_fail++; } - elapsed = timer.Elapsed(); } delete decode_fst; // delete this only after decoder goes out of scope. } else { // We have different FSTs for different utterances. 
assert(0); } + double elapsed = timer.Elapsed(); KALDI_LOG << "Time taken "<< elapsed << "s: real-time factor assuming 100 frames/sec is " << (elapsed*100.0/frame_count); @@ -308,14 +271,6 @@ int main(int argc, char *argv[]) { << frame_count<<" frames."; delete word_syms; - - delete const_arpa; - delete new_lm_fst; - delete new_lm_dfst; - delete word_embedding_mat; - delete rnnlm; - delete info; - if (num_success != 0) return 0; else return 1; } catch(const std::exception &e) { diff --git a/src/decoder/lattice-biglm-faster-decoder.h b/src/decoder/lattice-biglm-faster-decoder.h index 15abe4c0482..029ecb2b299 100644 --- a/src/decoder/lattice-biglm-faster-decoder.h +++ b/src/decoder/lattice-biglm-faster-decoder.h @@ -101,7 +101,6 @@ class LatticeBiglmFasterDecoder { active_toks_[0].toks = start_tok; toks_.Insert(start_pair, start_tok); num_toks_++; - propage_lm_num_=0; ProcessNonemitting(0); // We use 1-based indexing for frames in this decoder (if you view it in @@ -119,7 +118,6 @@ class LatticeBiglmFasterDecoder { else if (frame % config_.prune_interval == 0) PruneActiveTokens(frame, config_.lattice_beam * 0.1); // use larger delta. } - KALDI_VLOG(1) << "propage_lm_num_: " << propage_lm_num_; // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). return !final_costs_.empty(); @@ -570,12 +568,7 @@ class LatticeBiglmFasterDecoder { } } } - int32 ToksNum(int32 f) { - int32 c=0; - for (Token *t=active_toks_[f].toks; t; t=t->next) c++; - return c; - } - + // Go backwards through still-alive tokens, pruning them. note: cur_frame is // where hash toks_ are (so we do not want to mess with it because these tokens // don't yet have forward pointers), but we do all previous frames, unless we @@ -607,7 +600,6 @@ class LatticeBiglmFasterDecoder { } KALDI_VLOG(3) << "PruneActiveTokens: pruned tokens from " << num_toks_begin << " to " << num_toks_; - KALDI_VLOG(2) << "expand fr num: " << cur_frame-config_.prune_interval << " " << ToksNum(cur_frame-config_.prune_interval); } // Version of PruneActiveTokens that we call on the final frame. @@ -686,7 +678,6 @@ class LatticeBiglmFasterDecoder { if (arc->olabel == 0) { return lm_state; // no change in LM state if no word crossed. } else { // Propagate in the LM-diff FST. - propage_lm_num_++; Arc lm_arc; bool ans = lm_diff_fst_->GetArc(lm_state, arc->olabel, &lm_arc); if (!ans) { // this case is unexpected for statistical LMs. @@ -908,7 +899,6 @@ class LatticeBiglmFasterDecoder { active_toks_.clear(); KALDI_ASSERT(num_toks_ == 0); } - uint64 propage_lm_num_; }; } // end namespace kaldi. diff --git a/src/decoder/lattice2-biglm-faster-decoder.cc b/src/decoder/lattice2-biglm-faster-decoder.cc index 9cd02b8c6d6..fad92bb0ab0 100644 --- a/src/decoder/lattice2-biglm-faster-decoder.cc +++ b/src/decoder/lattice2-biglm-faster-decoder.cc @@ -1,6 +1,7 @@ // decoder/lattice2-biglm-faster-decoder.h -// Copyright 2018 Hang Lyu Zhehuai Chen +// Copyright 2018 Johns Hopkins University (Author: Daniel Povey) +// Hang Lyu // See ../../COPYING for clarification regarding multiple authors // @@ -32,17 +33,39 @@ Lattice2BiglmFasterDecoder::Lattice2BiglmFasterDecoder( lm_diff_fst->Start() != fst::kNoStateId); toks_.SetSize(1000); // just so on the first frame we do something reasonable. for (int i = 0; i < 2; i++) toks_shadowing_[i].SetSize(1000); // just so on the first frame we do something reasonable. 
+ toks_backfill_pair_.resize(0); toks_backfill_hclg_.resize(0); } bool Lattice2BiglmFasterDecoder::Decode(DecodableInterface *decodable) { + // clean up from last time. + DeleteElems(toks_.Clear()); + for (int i = 0; i < 2; i++) DeleteElemsShadow(toks_shadowing_[i]); + ClearActiveTokens(); + + // clean up private members + warned_noarc_ = false; + warned_ = false; + final_active_ = false; + final_costs_.clear(); + num_toks_ = 0; + + // At the beginning of an utterance, initialize. + toks_backfill_pair_.resize(0); + toks_backfill_hclg_.resize(0); + PairId start_pair = ConstructPair(fst_.Start(), lm_diff_fst_->Start()); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, lm_diff_fst_->Start(), fst_.Start()); + active_toks_[0].toks = start_tok; + toks_.Insert(start_pair, start_tok); + toks_shadowing_[NumFramesDecoded()%2].Insert(fst_.Start(), start_tok); + num_toks_++; + ProcessNonemitting(0); - InitDecoding(); // We use 1-based indexing for frames in this decoder (if you view it in // terms of features), but note that the decodable object uses zero-based // numbering, which we have to correct for when we call it. - int32 last_expand_frame=0; for (int32 frame = 1; !decodable->IsLastFrame(frame-2); frame++) { active_toks_.resize(frame+1); // new column @@ -50,30 +73,30 @@ bool Lattice2BiglmFasterDecoder::Decode(DecodableInterface *decodable) { ProcessNonemitting(frame); - if (frame % config_.prune_interval == 0) { + // Update the backward-cost of each token + if (frame % 5 == 0) UpdateBackwardCost(frame, config_.lattice_beam * 0.1); + + //if (decodable->IsLastFrame(frame-1)) + // PruneActiveTokensFinal(frame); + //else if (frame % config_.prune_interval == 0) + // PruneActiveTokens(frame, config_.lattice_beam * 0.1); // use larger delta. + if (frame % config_.prune_interval == 0) PruneActiveTokens(frame, config_.lattice_beam * 0.1); // use larger delta. - } - int32 t = frame-config_.prune_interval-config_.explore_interval; - if (t >= 0 && (frame-config_.explore_interval) % config_.prune_interval == 0) { - KALDI_ASSERT(t==last_expand_frame); - for (; t<=frame; t++) - ExpandShadowTokens(t, frame-config_.explore_interval-1, decodable, t==last_expand_frame); - last_expand_frame=frame-config_.explore_interval; - } // We could add another config option to decide the gap between state passing // and lm passing. + if (frame-config_.prune_interval >= 0) ExpandShadowTokens(frame-config_.prune_interval); } - PruneActiveTokens(NumFramesDecoded(), config_.lattice_beam * 0.1); - - for (int32 t=last_expand_frame; t<=NumFramesDecoded(); t++) - ExpandShadowTokens(t, NumFramesDecoded(), decodable, t==last_expand_frame); - // Process the last few frames lm passing - PruneActiveTokensFinal(NumFramesDecoded(), true); // with sanity check - KALDI_VLOG(1) << "propage_lm_num_: " << propage_lm_expand_num_ << " " << propage_lm_num_; + for (int32 frame = std::max(0, NumFramesDecoded() - config_.prune_interval + 1); // final expand + frame < NumFramesDecoded() + 1; frame++) { + if (frame % 5 == 0) UpdateBackwardCost(frame, config_.lattice_beam * 0.1); + ExpandShadowTokens(frame); + } + + PruneActiveTokensFinal(NumFramesDecoded()); // Returns true if we have any kind of traceback available (not necessarily // to the end state; query ReachedFinal() for that). 
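For reference, the lm_diff_fst_ consulted by this decoder is built in the two binaries above as the on-demand composition of the negated old LM with the new LM. A condensed sketch of that setup, with the <StdArc> template arguments that the formatting above stripped restored (names as in the binaries; error handling omitted):

    // Requires "fstext/fstext-lib.h" and "fstext/kaldi-fst-io.h".
    // old_lm_fst_rxfilename / new_lm_fst_rxfilename are the command-line arguments.
    fst::VectorFst<fst::StdArc> *old_lm_fst = fst::CastOrConvertToVectorFst(
        fst::ReadFstKaldiGeneric(old_lm_fst_rxfilename));
    fst::ApplyProbabilityScale(-1.0, old_lm_fst);  // negate the old LM scores

    fst::VectorFst<fst::StdArc> *new_lm_fst = fst::CastOrConvertToVectorFst(
        fst::ReadFstKaldiGeneric(new_lm_fst_rxfilename));

    fst::BackoffDeterministicOnDemandFst<fst::StdArc> old_lm_dfst(*old_lm_fst);
    fst::BackoffDeterministicOnDemandFst<fst::StdArc> new_lm_dfst(*new_lm_fst);
    fst::ComposeDeterministicOnDemandFst<fst::StdArc> compose_dfst(&old_lm_dfst,
                                                                   &new_lm_dfst);
    fst::CacheDeterministicOnDemandFst<fst::StdArc> cache_dfst(&compose_dfst, 1e7);
    // &cache_dfst is what the binaries hand to the decoder as its LM-diff FST.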
@@ -82,128 +105,230 @@ bool Lattice2BiglmFasterDecoder::Decode(DecodableInterface *decodable) { bool Lattice2BiglmFasterDecoder::Decode(DecodableInterface *decodable, const Vector &cutoff) { - // //initial cutoff_ - if (cutoff.Dim()) { - cutoff_.Resize(cutoff.Dim()); - cutoff_ = cutoff; - } else { - cutoff_.Resize(1); - cutoff_.Data()[0] = std::numeric_limits::max(); + // clean up from last time. + DeleteElems(toks_.Clear()); + for (int i = 0; i < 2; i++) DeleteElemsShadow(toks_shadowing_[i]); + ClearActiveTokens(); + + // initial cutoff_ + cutoff_.Resize(cutoff.Dim()); + cutoff_ = cutoff; + + // clean up private members + warned_noarc_ = false; + warned_ = false; + final_active_ = false; + final_costs_.clear(); + num_toks_ = 0; + + // At the beginning of an utterance, initialize. + toks_backfill_pair_.resize(0); + toks_backfill_hclg_.resize(0); + PairId start_pair = ConstructPair(fst_.Start(), lm_diff_fst_->Start()); + active_toks_.resize(1); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, lm_diff_fst_->Start(), fst_.Start()); + active_toks_[0].toks = start_tok; + toks_.Insert(start_pair, start_tok); + toks_shadowing_[NumFramesDecoded()%2].Insert(fst_.Start(), start_tok); + num_toks_++; + ProcessNonemitting(0); + + // We use 1-based indexing for frames in this decoder (if you view it in + // terms of features), but note that the decodable object uses zero-based + // numbering, which we have to correct for when we call it. + for (int32 frame = 1; !decodable->IsLastFrame(frame-2); frame++) { + active_toks_.resize(frame+1); // new column + + ProcessEmitting(decodable, frame); + + ProcessNonemitting(frame); + + // Update the backward-cost of each token + if (frame % 5 == 0) UpdateBackwardCost(frame, config_.lattice_beam * 0.1); + + //if (decodable->IsLastFrame(frame-1)) + // PruneActiveTokensFinal(frame); + //else if (frame % config_.prune_interval == 0) + // PruneActiveTokens(frame, config_.lattice_beam * 0.1); // use larger delta. + if (frame % config_.prune_interval == 0) + PruneActiveTokens(frame, config_.lattice_beam * 0.1); // use larger delta. + + + // We could add another config option to decide the gap between state passing + // and lm passing. + if (frame-config_.prune_interval >= 0) ExpandShadowTokens(frame-config_.prune_interval); + } + + // Process the last few frames lm passing + for (int32 frame = std::max(0, NumFramesDecoded() - config_.prune_interval + 1); // final expand + frame < NumFramesDecoded() + 1; frame++) { + if (frame % 5 == 0) UpdateBackwardCost(frame, config_.lattice_beam * 0.1); + ExpandShadowTokens(frame); } - return Decode(decodable); + + PruneActiveTokensFinal(NumFramesDecoded()); + + // Returns true if we have any kind of traceback available (not necessarily + // to the end state; query ReachedFinal() for that). 
+ return !final_costs_.empty(); } -void Lattice2BiglmFasterDecoder::ExpandShadowTokens(int32 cur_frame, int32 frame_stop_expand, DecodableInterface *decodable, bool first) { - Timer timer; - - expanding_=true; - bool is_last = cur_frame <= frame_stop_expand; // the last time we do expand in this frame - KALDI_ASSERT(cur_frame >= 0); - auto& cur_q = GetExpandQueue(cur_frame); - auto& cur_h = GetBackfillMap(cur_frame); - if (cur_frame > frame_stop_expand && !cur_q.size()) { - expanding_=false; - return; +void Lattice2BiglmFasterDecoder::ExpandShadowTokens(int32 frame) { + assert(frame >= 0); + + BuildBackfillMap(frame); + if ( (frame + 1) <= active_toks_.size()) { + BuildBackfillMap(frame+1); } - if (first) BuildBackfillMap(cur_frame, frame_stop_expand, first); - if ( (cur_frame + 1) < active_toks_.size()) { - BuildBackfillMap(cur_frame+1, frame_stop_expand, true); + if (active_toks_[frame].toks == NULL) { + KALDI_WARN << "ExpandShadowTokens: no tokens active on frame " << frame; } - while (!cur_q.empty()) { - auto q_elem= cur_q.front(); - cur_q.pop(); - Token* tok = q_elem.first; - tok->in_queue=false; - bool cur_better_hclg = q_elem.second; - int32 frame = cur_frame; - BaseFloat cur_cutoff = (frame+1 < cutoff_.Dim())? -cutoff_(frame+1) : std::numeric_limits::infinity(); - - if (tok->tot_cost > cur_cutoff) { - tok->shadowing_tok = NULL; // already expand - continue; - } + BaseFloat cur_cutoff = cutoff_(frame); - ForwardLink *link=NULL, *links_to_clear=NULL; - if (tok->shadowing_tok == NULL) { - // if we need to update a shadowing token itself - link=tok->links; - tok->links=NULL; // we firstly un-hook it from shadowing token - links_to_clear=link; // we will reconstruct it and delete the original one later - } else { - // otherwise, we are updating a shadowed token - // we obtain template links from shadowing token - // sanity check: - // KALDI_ASSERT(toks_backfill_hclg_[frame]->find(tok->hclg_state)->second==tok->shadowing_tok); - KALDI_ASSERT(!tok->links); - Token* shadowing_tok = tok; - while (shadowing_tok->shadowing_tok && !shadowing_tok->links) shadowing_tok = shadowing_tok->shadowing_tok; - // Update toks_shadowing_mod for better_hclg here - // Notice that we only update if it reaches NumFramesDecoded(), since it will affect explore - if (frame == NumFramesDecoded() && *tok > *shadowing_tok) { - HashList &toks_shadowing_mod=toks_shadowing_[frame%2]; - ElemShadow *elem = toks_shadowing_mod.Find(tok->hclg_state); - if (elem) { - if (*tok < *elem->val) { - elem->val = tok; - // sanity check - // KALDI_ASSERT(!(*toks_backfill_hclg_[frame]).find(tok->hclg_state)->second->shadowing_tok); - } - } else // from better_hclg - toks_shadowing_mod.Insert(tok->hclg_state, tok); - } - link = shadowing_tok->links; - if (!link) { - // for the end of decoding, we need to expand all - if (is_last) - tok->shadowing_tok = NULL; - else if (frame == NumFramesDecoded()) { - auto iter = *toks_backfill_hclg_[frame]->find(tok->hclg_state); - KALDI_ASSERT(iter.second != tok); - if (*iter.second > *tok) // better_hclg - tok->shadowing_tok = NULL; - else - KALDI_ASSERT(tok->shadowing_tok == iter.second); - } // for normal shadowed token is_last==false, we process it later - continue; + /* + // Find the minimum backward cost in this frame + BaseFloat best_backward_cost = std::numeric_limits::infinity(); + BaseFloat worst_backward_cost = std::numeric_limits::lowest(); + for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + best_backward_cost = std::min(best_backward_cost, 
tok->backward_cost); + if (tok->backward_cost != std::numeric_limits::infinity()) + worst_backward_cost = std::max(worst_backward_cost, tok->backward_cost); + } + std::cout << "In frame " << frame + << " Best backward cost is " << best_backward_cost + << " Worst backward cost is " << worst_backward_cost << std::endl; + */ + /* + BaseFloat best_forward_cost = std::numeric_limits::infinity(); + BaseFloat worst_forward_cost = std::numeric_limits::lowest(); + for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + best_forward_cost = std::min(best_forward_cost, tok->tot_cost); + worst_forward_cost = std::max(worst_forward_cost, tok->tot_cost); + } + std::cout << "In frame " << frame + << " Best forward cost is " << best_forward_cost + << " Worst forward cost is " << worst_forward_cost << std::endl; + */ + + /* + BaseFloat best_fb_cost = std::numeric_limits::infinity(); + for (Token* tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + best_fb_cost = std::min(best_fb_cost, tok->tot_cost + tok->backward_cost); + } + */ + + /* + std::cout << "Show pair map." << std::endl; + for(unordered_map::iterator iter = toks_backfill_pair_[frame]->begin(); + iter != toks_backfill_pair_[frame]->end(); iter++) { + Token* tok = iter->second; + if (tok->shadowing_tok) { + std::cout << "Token is (" << tok->hclg_state << "," << tok->lm_state << "). And Shadowing token is (" + << tok->shadowing_tok->hclg_state << "," << tok->shadowing_tok->lm_state << ")" << std::endl; + } else { + std::cout << "Token is (" << tok->hclg_state << "," << tok->lm_state << ")" << std::endl; } } + */ + + // When we expand, if the arc.ilabel == 0, there maybe a new token is + // created in current frame. The queue "expand_current_frame_queue" is used + // to deal with all tokens in current frame in loop. + std::queue expand_current_frame_queue; + // Initialize the "expand_current_frame_queue" with current tokens. + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + if (!tok->shadowing_tok) continue; // shadowing token + // If shadowing_tok == NULL, it means + // this token has been processed. Or it + // is the best one in this HCLG state + // in the frame. Skip it. + // Decide which token should be expanded + if (tok->tot_cost <= cur_cutoff) { + expand_current_frame_queue.push(ConstructPair(tok->hclg_state, + tok->lm_state)); + } + } - if (cur_better_hclg && config_.better_hclg==2) { - for (fst::ArcIterator > aiter(fst_, tok->hclg_state); - !aiter.Done(); - aiter.Next()) { - Arc arc = aiter.Value(); - StateId ilabel = arc.ilabel; - int32 new_frame_index = ilabel ? frame+1 : frame; - if (new_frame_index > NumFramesDecoded()) continue; - BaseFloat graph_cost_ori = arc.weight.Value(); - StateId new_hclg_state = arc.nextstate; - Arc new_arc(arc); - StateId new_lm_state = PropagateLm(tok->lm_state, &arc); // may affect "arc.weight". - BaseFloat ac_cost = ilabel ? 
-decodable->LogLikelihood(frame, ilabel) : 0, - graph_cost = new_arc.weight.Value(), - cur_cost = tok->tot_cost, - tot_cost = cur_cost + ac_cost + graph_cost; - - BaseFloat extra_cost = tok->extra_cost, // TODO - backward_cost = tok->backward_cost; + while (!expand_current_frame_queue.empty()) { + PairId cur_id = expand_current_frame_queue.front(); + expand_current_frame_queue.pop(); + + KALDI_ASSERT(toks_backfill_pair_[frame]->find(cur_id) != + toks_backfill_pair_[frame]->end()); - // prepare to store a new token in the current / next frame - if (new_frame_index+1 < cutoff_.Dim() && - tot_cost > cutoff_(new_frame_index+1)) continue; - if (extra_cost > config_.lattice_beam) continue; - Token* new_tok = ExpandShadowTokensSub(ilabel, new_hclg_state, new_lm_state, frame, new_frame_index, tot_cost, extra_cost, backward_cost, is_last); - // create lattice arc - tok->links = new ForwardLink(new_tok, arc.ilabel, arc.olabel, - graph_cost, ac_cost, tok->links, - graph_cost_ori); + Token* tok = (*toks_backfill_pair_[frame])[cur_id]; + //if (tok->tot_cost > best_forward_cost + config_.expand_beam) continue; + if (tok->tot_cost > cur_cutoff) + continue; + + Token* shadowing_tok = tok->shadowing_tok; + + /* + std::cout << "Expanding: Frame " << frame + << " cur_token HCLG_id is " << tok->hclg_state + << " LM_id is " << tok->lm_state + << " .And shadowing_token HCLG_id is " << shadowing_tok->hclg_state + << " LM_id is " << shadowing_tok->lm_state << std::endl; + */ + for (ForwardLink *link = shadowing_tok->links; link != NULL; link = link->next) { + Token *next_tok = link->next_tok; + + Arc arc(link->ilabel, link->olabel, link->graph_cost_ori, 0); + StateId new_hclg_state = next_tok->hclg_state; + StateId new_lm_state = PropagateLm(tok->lm_state, &arc); // may affect "arc.weight". + PairId new_pair = ConstructPair(new_hclg_state, new_lm_state); + BaseFloat ac_cost = link->acoustic_cost, + graph_cost = arc.weight.Value(), + cur_cost = tok->tot_cost, + tot_cost = cur_cost + ac_cost + graph_cost; + // The extra_cost and backward_cost are temporary. They are inherited from + // "next_tok" which is the destation of "shadowing_token". So they are + // estimated rather than exact. They will be used to initialize a new + // token and help to decide the new token will be expanded or not, as the + // backward_cost value will be updated periodly rather than frame-by-frame. + BaseFloat extra_cost = next_tok->extra_cost, + backward_cost = next_tok->backward_cost; + + // prepare to store a new token in the current / next frame + int32 new_frame_index = link->ilabel ? frame+1 : frame; + Token *&toks = link->ilabel ? active_toks_[frame+1].toks : active_toks_[frame].toks; + assert(toks); + + /* + std::cout << "ilabel is " << link->ilabel + << " Next token HCLG_id is " << next_tok->hclg_state + << " LM_id is " << next_tok->lm_state + << " extra_cost is " << next_tok->extra_cost + << " New token HCLG_id is " << new_hclg_state + << " LM_id is " << new_lm_state + << " extra_cost is " << extra_cost << std::endl; + */ + + bool exist_flag = false; + if (toks_backfill_pair_[new_frame_index]->find(new_pair) != + toks_backfill_pair_[new_frame_index]->end()) { + exist_flag = true; } - } else { + + // Special case: An arc that we expand in backfill reaches an existing + // state but it gives that state a better forward cost than before. 
+ /* + if (exist_flag && tot_cost < + (*toks_backfill_pair_[new_frame_index])[new_pair]->tot_cost) { + // Update the destination token + ProcessBetterExistingToken(new_frame_index, new_pair, tot_cost); + } + */ + if (exist_flag && tot_cost < + (*toks_backfill_pair_[new_frame_index])[new_pair]->tot_cost) { + // Update the destination token + (*toks_backfill_pair_[new_frame_index])[new_pair]->tot_cost = tot_cost; + } + // There will be four kinds of links need to be processed. // 1. Go to next frame and the corresponding "next_tok" is shadowed // 2. Go to next frame and the corresponding "next_tok" is the processed @@ -215,55 +340,81 @@ cutoff_(frame+1) : std::numeric_limits::infinity(); // processed.(Under most circumstances, it is the best one and processed // in explore step) // However, the way to deal with them is similar. - for (; link != NULL; link = link->next) { - Token *next_tok = link->next_tok; - while (next_tok->shadowing_tok && !next_tok->links) next_tok=next_tok->shadowing_tok; - StateId ilabel = link->ilabel; - int32 new_frame_index = ilabel ? frame+1 : frame; - if (new_frame_indexlinks) continue; // this link should be pruned - - Arc arc(ilabel, link->olabel, link->graph_cost_ori, 0); - BaseFloat graph_cost_ori = link->graph_cost_ori; // TODO - StateId new_hclg_state = next_tok->hclg_state; - StateId new_lm_state = PropagateLm(tok->lm_state, &arc); // may affect "arc.weight". - BaseFloat ac_cost = link->acoustic_cost, - graph_cost = arc.weight.Value(), - cur_cost = tok->tot_cost, - tot_cost = cur_cost + ac_cost + graph_cost; - - // The extra_cost and backward_cost are temporary. They are inherited from - // "next_tok" which is the destation of "shadowing_token". So they are - // estimated rather than exact. They will be used to initialize a new - // token and help to decide the new token will be expanded or not - BaseFloat extra_cost = next_tok->extra_cost + tot_cost - next_tok->tot_cost, // inherit backward cost, use its own tot_cost - backward_cost = next_tok->backward_cost; - - // prepare to store a new token in the current / next frame - if (new_frame_index+1 < cutoff_.Dim() && - tot_cost > cutoff_(new_frame_index+1)) continue; - if (extra_cost > config_.lattice_beam) continue; - Token* new_tok = ExpandShadowTokensSub(ilabel, new_hclg_state, new_lm_state, frame, new_frame_index, tot_cost, extra_cost, backward_cost, is_last); - // create lattice arc - tok->links = new ForwardLink(new_tok, arc.ilabel, arc.olabel, - graph_cost, ac_cost, tok->links, - graph_cost_ori); - } // end of for loop - } - while (links_to_clear) { - ForwardLink* l=links_to_clear->next; - delete links_to_clear; - links_to_clear = l; - } + Token* new_tok; + if (!exist_flag) { // A new token. + // Construct the new token. + new_tok = new Token(tot_cost, extra_cost, NULL, toks, new_lm_state, + new_hclg_state, backward_cost); + toks = new_tok; + num_toks_++; + + // Add the new token to "backfill" map. + (*toks_backfill_pair_[new_frame_index])[new_pair] = new_tok; + + new_tok->shadowing_tok = + (*toks_backfill_hclg_[new_frame_index])[new_hclg_state]; + + // Still a shadowed token. Push into queue. 
+ if (link->ilabel == 0 && new_tok->shadowing_tok) { + expand_current_frame_queue.push(new_pair); + } + } else { // An existing token + new_tok = (*toks_backfill_pair_[new_frame_index])[new_pair]; + } + + // create lattice arc + tok->links = new ForwardLink(new_tok, arc.ilabel, arc.olabel, + graph_cost, ac_cost, tok->links, + link->graph_cost_ori); + + + + // Special case: A previously unseen state was created that has a higher + // probability than an existing copy of the same HCLG.fst state. + /* + if (tot_cost < + (*toks_backfill_hclg_[new_frame_index])[new_hclg_state]->tot_cost) { + ProcessBetterHCLGToken(new_frame_index, new_tok); + // The token has been processed along previous-best-HCLG token + new_tok->shadowing_tok = NULL; + } else { + if (!exist_flag) { + new_tok->shadowing_tok = + (*toks_backfill_hclg_[new_frame_index])[new_hclg_state]; + } + } + */ + } // end of for loop tok->shadowing_tok = NULL; // already expand } // Clean the backfill map - cur_h.clear(); - if (is_last) toks_backfill_hclg_[cur_frame]->clear(); + toks_backfill_pair_[frame]->clear(); + toks_backfill_hclg_[frame]->clear(); - KALDI_VLOG(2) << "expand fr num: " << cur_frame << " " << is_last << " " << GetExpandQueue(cur_frame+1).size() << " " << ToksNum(cur_frame); - expanding_=false; - expand_time_ += timer.Elapsed(); + /* + std::cout << "CHeck the shadowing token in ExpandShadowTokens() on frame " + << frame << std::endl; + for (int32 frame_idx = 0; frame_idx <= active_toks_.size(); frame_idx++) { + for (Token *tok = active_toks_[frame_idx].toks; tok != NULL; tok = tok->next) { + if (tok->shadowing_tok != NULL) { // shadowed token + for (ForwardLink *link = tok->links; link != NULL; + link = link->next) { + if (link->ilabel != 0) { + std::cout << "The bug is in frame " << frame_idx << std::endl; + std::cout << "The error token HCLG_id is " << tok->hclg_state + << " LM_id is " << tok->lm_state + << " .And shadowing token HCLG_id is " << tok->shadowing_tok->hclg_state + << " LM_id is " << tok->shadowing_tok->lm_state << std::endl; + } + KALDI_ASSERT(link->ilabel == 0); + } + } else { // exploring token + + } + } + } + */ } @@ -369,7 +520,7 @@ bool Lattice2BiglmFasterDecoder::GetLattice( void Lattice2BiglmFasterDecoder::PruneForwardLinks(int32 frame, bool *extra_costs_changed, bool *links_pruned, - BaseFloat delta, bool is_expand=false) { + BaseFloat delta) { // delta is the amount by which the extra_costs must change // If delta is larger, we'll tend to go back less far // toward the beginning of the file. @@ -397,29 +548,44 @@ void Lattice2BiglmFasterDecoder::PruneForwardLinks(int32 frame, // will recompute tok_extra_cost for tok. BaseFloat tok_extra_cost = std::numeric_limits::infinity(); // tok_extra_cost is the best (min) of link_extra_cost of outgoing links - if (tok->shadowing_tok && tok->links) { // has been expanded - if (*tok > *tok->shadowing_tok) tok->DeleteForwardLinks(); - else tok->shadowing_tok = NULL; - } for (link = tok->links; link != NULL; ) { // See if we need to excise this link... 
Token *next_tok = link->next_tok; BaseFloat link_extra_cost = 0.0; - if (is_expand && next_tok->shadowing_tok) { - KALDI_ASSERT(!next_tok->links); - next_tok->shadowing_tok=NULL; // hasn't pruned but it should do - } if (next_tok->shadowing_tok) { - Token* s=next_tok->shadowing_tok; - while (s->shadowing_tok) s=s->shadowing_tok; - link_extra_cost = s->extra_cost + + link_extra_cost = next_tok->shadowing_tok->extra_cost + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - - s->tot_cost); + - next_tok->shadowing_tok->tot_cost); } else { link_extra_cost = next_tok->extra_cost + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - next_tok->tot_cost); // difference in brackets is >= 0 } + /* + Token *next_tok = link->next_tok; + if (next_tok->shadowing_tok) { // excise this link depend on the + // extra_cost of the shadowing token + if (next_tok->shadowing_tok->extra_cost == + std::numeric_limits::infinity()) { // The next_tok + // will be deleted + ForwardLink *next_link = link->next; + if (prev_link != NULL) prev_link->next = link->next; + else tok->links = link->next; + delete link; + link = next_link; // advance link but leave prev_link the same. + *links_pruned = true; + continue; + } else { + prev_link = link; // move to next link + link = link->next; + continue; + } + } + + BaseFloat link_extra_cost = next_tok->extra_cost + + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) + - next_tok->tot_cost); // difference in brackets is >= 0 + */ // link_exta_cost is the difference in score between the best paths // through link source state and through link destination state KALDI_ASSERT(link_extra_cost == link_extra_cost); // check for NaN @@ -428,26 +594,26 @@ void Lattice2BiglmFasterDecoder::PruneForwardLinks(int32 frame, if (prev_link != NULL) prev_link->next = next_link; else tok->links = next_link; delete link; - link = next_link; // advance link but leave prev_link the same. + link = next_link; // advance link but leave prev_link the same. *links_pruned = true; - } else { // keep the link and update the tok_extra_cost if needed. - if (link_extra_cost < 0.0) { // this is just a precaution. + } else { // keep the link and update the tok_extra_cost if needed. + if (link_extra_cost < 0.0) { // this is just a precaution. if (link_extra_cost < -0.01) - //KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; link_extra_cost = 0.0; } if (link_extra_cost < tok_extra_cost) tok_extra_cost = link_extra_cost; - prev_link = link; // move to next link + prev_link = link; // move to next link link = link->next; } - } // for all outgoing links + } // for all outgoing links if (fabs(tok_extra_cost - tok->extra_cost) > delta) - changed = true; // difference new minus old is bigger than delta + changed = true; // difference new minus old is bigger than delta tok->extra_cost = tok_extra_cost; // will be +infinity or <= lattice_beam_. // infinity indicates, that no forward link survived pruning - } // for all Token on active_toks_[frame] + } // for all Token on active_toks_[frame] if (changed) *extra_costs_changed = true; // Note: it's theoretically possible that aggressive compiler @@ -509,7 +675,11 @@ void Lattice2BiglmFasterDecoder::PruneForwardLinksFinal(int32 frame) { for (link = tok->links; link != NULL; ) { // See if we need to excise this link... 
Token *next_tok = link->next_tok; - KALDI_ASSERT(!next_tok->shadowing_tok); + if (next_tok->shadowing_tok) { // TODO + prev_link = link; // move to next link + link = link->next; + continue; + } BaseFloat link_extra_cost = next_tok->extra_cost + ((tok->tot_cost + link->acoustic_cost + link->graph_cost) - next_tok->tot_cost); @@ -522,7 +692,7 @@ void Lattice2BiglmFasterDecoder::PruneForwardLinksFinal(int32 frame) { } else { // keep the link and update the tok_extra_cost if needed. if (link_extra_cost < 0.0) { // this is just a precaution. if (link_extra_cost < -0.01) - //KALDI_WARN << "Negative extra_cost: " << link_extra_cost; + KALDI_WARN << "Negative extra_cost: " << link_extra_cost; link_extra_cost = 0.0; } if (link_extra_cost < tok_extra_cost) @@ -561,7 +731,7 @@ void Lattice2BiglmFasterDecoder::PruneForwardLinksFinal(int32 frame) { } -void Lattice2BiglmFasterDecoder::PruneTokensForFrame(int32 frame, bool is_expand=false) { +void Lattice2BiglmFasterDecoder::PruneTokensForFrame(int32 frame) { KALDI_ASSERT(frame >= 0 && frame < active_toks_.size()); Token *&toks = active_toks_[frame].toks; if (toks == NULL) @@ -570,10 +740,6 @@ void Lattice2BiglmFasterDecoder::PruneTokensForFrame(int32 frame, bool is_expand // proc shadowed token at first as it needs info from shadowing token for (tok = toks; tok != NULL; tok = next_tok) { next_tok = tok->next; - if (is_expand && tok->shadowing_tok) { - KALDI_ASSERT(!tok->links); - tok->shadowing_tok=NULL; // hasn't pruned but it should do - } if (tok->shadowing_tok) {// shadowed token if (tok->shadowing_tok->extra_cost == std::numeric_limits::infinity()) { // token is unreachable from end of graph; (no forward links survived) @@ -586,10 +752,8 @@ void Lattice2BiglmFasterDecoder::PruneTokensForFrame(int32 frame, bool is_expand prev_tok = tok; // KALDI_ASSERT(tok->shadowing_tok->tot_cost <= tok->tot_cost); // After expanding, sometimes the tok->tot_cost better than shadowing's. - Token* s=tok->shadowing_tok; - while (s->shadowing_tok) s=s->shadowing_tok; - tok->extra_cost = s->extra_cost + - s->tot_cost - tok->tot_cost; + tok->extra_cost = tok->shadowing_tok->extra_cost + + tok->shadowing_tok->tot_cost - tok->tot_cost; } } else { prev_tok = tok; @@ -602,10 +766,6 @@ void Lattice2BiglmFasterDecoder::PruneTokensForFrame(int32 frame, bool is_expand if (tok->extra_cost == std::numeric_limits::infinity()) { // token is unreachable from end of graph; (no forward links survived) // excise tok from list and delete tok. 
- if (toks_backfill_hclg_.size()>frame && frame >= NumFramesDecoded()-config_.prune_interval) { // the map has been built - if (toks_backfill_hclg_[frame]->erase(tok->hclg_state)) - ; //for (Token* t=toks; t; t=t->next) KALDI_ASSERT(t->shadowing_tok!=tok); // sanity check - } if (prev_tok != NULL) prev_tok->next = tok->next; else toks = tok->next; delete tok; @@ -640,6 +800,10 @@ void Lattice2BiglmFasterDecoder::PruneActiveTokens(int32 cur_frame, if (frame+1 < cur_frame && // except for last frame (no forward links) active_toks_[frame+1].must_prune_tokens) { PruneTokensForFrame(frame+1); + // Check + for (Token *tok = active_toks_[frame+1].toks; tok != NULL; tok = tok->next) { + KALDI_ASSERT(tok->extra_cost != std::numeric_limits::infinity()); + } active_toks_[frame+1].must_prune_tokens = false; } } @@ -648,7 +812,7 @@ void Lattice2BiglmFasterDecoder::PruneActiveTokens(int32 cur_frame, } -void Lattice2BiglmFasterDecoder::PruneActiveTokensFinal(int32 cur_frame, bool is_expand) { +void Lattice2BiglmFasterDecoder::PruneActiveTokensFinal(int32 cur_frame) { // returns true if there were final states active // else returns false and treats all states as final while doing the pruning // (this can be useful if you want partial lattice output, @@ -661,10 +825,10 @@ void Lattice2BiglmFasterDecoder::PruneActiveTokensFinal(int32 cur_frame, bool is for (int32 frame = cur_frame-1; frame >= 0; frame--) { bool b1, b2; // values not used. BaseFloat dontcare = 0.0; // delta of zero means we must always update - PruneForwardLinks(frame, &b1, &b2, dontcare, is_expand); - PruneTokensForFrame(frame+1, is_expand); + PruneForwardLinks(frame, &b1, &b2, dontcare); + PruneTokensForFrame(frame+1); } - PruneTokensForFrame(0, is_expand); + PruneTokensForFrame(0); KALDI_VLOG(3) << "PruneActiveTokensFinal: pruned tokens from " << num_toks_begin << " to " << num_toks_; } @@ -721,79 +885,9 @@ BaseFloat Lattice2BiglmFasterDecoder::GetCutoff(Elem *list_head, } } -Lattice2BiglmFasterDecoder::Token* Lattice2BiglmFasterDecoder::ExpandShadowTokensSub(StateId ilabel, - StateId new_hclg_state, StateId new_lm_state, int32 frame, - int32 new_frame_index, BaseFloat tot_cost, BaseFloat extra_cost, BaseFloat backward_cost, - bool is_last) { - Token *&toks = ilabel ? active_toks_[frame+1].toks : active_toks_[frame].toks; - assert(toks); - - Token *tok_found=NULL; - PairId new_pair = ConstructPair(new_hclg_state, new_lm_state); - auto& next_h = GetBackfillMap(new_frame_index); - auto& next_q = GetExpandQueue(new_frame_index); - auto iter = next_h.find(new_pair); - if (iter != next_h.end()) - tok_found = iter->second; - - Token* new_tok; - bool update_tok=false; - if (!tok_found) { // A new token. - // Construct the new token. - new_tok = new Token(tot_cost, extra_cost, NULL, toks, new_lm_state, - new_hclg_state, backward_cost); - toks = new_tok; - num_toks_++; - - // Add the new token to "backfill" map. 
- next_h[new_pair] = new_tok; - update_tok=true; - } else { // An existing token - new_tok = tok_found; - if (new_tok->tot_cost > tot_cost) { - new_tok->tot_cost = tot_cost; - new_tok->backward_cost = backward_cost; - new_tok->extra_cost = extra_cost; - update_tok=true; - } - } - - bool better_hclg=false; - KALDI_ASSERT(toks_backfill_hclg_.size() > new_frame_index); - auto iter_hclg = (*toks_backfill_hclg_[new_frame_index]).find(new_hclg_state); - if (iter_hclg != (*toks_backfill_hclg_[new_frame_index]).end()) { - if (tot_cost < iter_hclg->second->tot_cost) - better_hclg=true; // search: "Update toks_shadowing_mod for better_hclg" - // although it is better hclg, we still keep its shadowing token for expanding in the next iter - } else { - (*toks_backfill_hclg_[new_frame_index])[new_hclg_state] = new_tok; - iter_hclg = (*toks_backfill_hclg_[new_frame_index]).find(new_hclg_state); - } - - if (update_tok && !new_tok->in_queue) { - new_tok->shadowing_tok = iter_hclg->second; // by default - if (new_tok->shadowing_tok == new_tok) { - // if new_tok is the shadowing token - // search the comments above regarding to: - // "we need to update a shadowing token itself" - new_tok->shadowing_tok = NULL; - } - if (is_last || better_hclg || new_frame_index == frame) { - if (new_tok->shadowing_tok) { // prepare for forwardlinks updating - // sanity check - // KALDI_ASSERT(!new_tok->shadowing_tok->shadowing_tok || new_tok->shadowing_tok->shadowing_tok != new_tok); - new_tok->DeleteForwardLinks(); - } - next_q.push(QElem(new_tok, better_hclg)); - new_tok->in_queue=true; - } - } - return new_tok; -} void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, int32 frame) { - Timer timer; // Processes emitting arcs for one frame. Propagates from prev_toks_ to cur_toks_. HashList &toks_shadowing_check=toks_shadowing_[(frame-1)%2]; HashList &toks_shadowing_mod=toks_shadowing_[frame%2]; @@ -808,10 +902,6 @@ void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, // TODO PossiblyResizeHash for toks_shadowing_ KALDI_VLOG(6) << "Adaptive beam on frame " << frame << "\t" << NumFramesDecoded() << " is " << adaptive_beam << "\t" << cur_cutoff; - if (cutoff_.Dim()<=frame) { - cutoff_.Resize(frame+1,kCopyData); - cutoff_.Data()[frame]=cur_cutoff; - } BaseFloat next_cutoff = std::numeric_limits::infinity(); @@ -829,7 +919,7 @@ void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, aiter.Next()) { Arc arc = aiter.Value(); if (arc.ilabel != 0) { // propagate.. - PropagateLm(lm_state ,&arc); // may affect "arc.weight". + PropagateLm(lm_state, &arc); // may affect "arc.weight". // We don't need the return value (the new LM state). 
arc.weight = Times(arc.weight, Weight(-decodable->LogLikelihood(frame-1, arc.ilabel))); @@ -852,13 +942,11 @@ void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, StateId state = PairToState(state_pair), lm_state = PairToLmState(state_pair); Token *tok = e->val; - if (tok->tot_cost < cur_cutoff) { + if (tok->tot_cost <= cur_cutoff) { ElemShadow *elem = toks_shadowing_check.Find(state); assert(elem); - if (elem->val == tok || // explore - !tok->shadowing_tok || - *tok < *tok->shadowing_tok) { // it is generated by better_hclg; otherwise tok->shadowing_tok should be set by pne in the last frame - tok->shadowing_tok = NULL; + if (elem->val == tok) { // explore + KALDI_ASSERT(tok->shadowing_tok == NULL); for (fst::ArcIterator > aiter(fst_, state); !aiter.Done(); aiter.Next()) { @@ -874,14 +962,12 @@ void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, if (tot_cost > next_cutoff) continue; else if (tot_cost + config_.beam < next_cutoff) next_cutoff = tot_cost + config_.beam; // prune by best current token - PairId next_pair = ConstructPair(arc.nextstate, next_lm_state); Token *next_tok = FindOrAddToken(next_pair, frame, tot_cost, true, NULL); // true: emitting, NULL: no change indicator needed ElemShadow *elem = toks_shadowing_mod.Find(arc.nextstate); - if (elem) { - if ((*elem->val) > *next_tok) - elem->val = next_tok; // update it + if (elem && elem->val->tot_cost > tot_cost) { + elem->val = next_tok; // update it } else { toks_shadowing_mod.Insert(arc.nextstate, next_tok); } @@ -893,18 +979,16 @@ void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, } // for all arcs KALDI_ASSERT(tok->shadowing_tok == NULL); // it's shadowing token } else { - KALDI_ASSERT(tok->shadowing_tok && *tok > *tok->shadowing_tok); + KALDI_ASSERT(tok->shadowing_tok != NULL); } } e_tail = e->tail; toks_.Delete(e); // delete Elem } - ta_+=timer.Elapsed(); } void Lattice2BiglmFasterDecoder::ProcessNonemitting(int32 frame) { - Timer timer; // note: "frame" is the same as emitting states just processed. // Processes nonemitting arcs for one frame. Propagates within toks_. 
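ProcessEmitting() above and ProcessNonemitting() below both maintain toks_shadowing_, which maps each HCLG state reached on the current frame to its cheapest token; every other token for that HCLG state is later marked as shadowed by it, skipped during the explore pass, and revisited by ExpandShadowTokens(). A minimal sketch of the invariant being maintained, assuming (the template arguments were stripped above) that toks_shadowing_ is a HashList<StateId, Token*> and ElemShadow its Elem type:

    // Keep the cheapest token per HCLG state on this frame.
    ElemShadow *elem = toks_shadowing_mod.Find(arc.nextstate);
    if (elem == NULL) {
      toks_shadowing_mod.Insert(arc.nextstate, next_tok);  // first token for this state
    } else if (elem->val->tot_cost > next_tok->tot_cost) {
      elem->val = next_tok;  // next_tok is now the best token for this HCLG state
    }
    // At the end of ProcessNonemitting(), any token that is not the stored best
    // for its HCLG state gets tok->shadowing_tok = elem->val (see below).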
@@ -967,9 +1051,8 @@ void Lattice2BiglmFasterDecoder::ProcessNonemitting(int32 frame) { Token *new_tok = FindOrAddToken(next_pair, frame, tot_cost, false, &changed); // false: non-emit ElemShadow *elem = toks_shadowing_mod.Find(arc.nextstate); - if (elem) { - if ((*elem->val) > *new_tok) - elem->val = new_tok; // update it + if (elem && elem->val->tot_cost > tot_cost) { + elem->val = new_tok; // update it } else { toks_shadowing_mod.Insert(arc.nextstate, new_tok); } @@ -982,7 +1065,7 @@ void Lattice2BiglmFasterDecoder::ProcessNonemitting(int32 frame) { } } } // for all arcs - } // if + } // end of if condition } // while queue not empty // Make the "shadowing_tok" pointer point to the best one with the same @@ -996,102 +1079,326 @@ void Lattice2BiglmFasterDecoder::ProcessNonemitting(int32 frame) { cur_tok->shadowing_tok = NULL; } else { cur_tok->shadowing_tok = elem->val; - // sanity check - KALDI_ASSERT(!cur_tok->shadowing_tok->shadowing_tok || cur_tok->shadowing_tok->shadowing_tok != cur_tok); - cur_tok->extra_cost=std::numeric_limits::infinity(); - cur_tok->DeleteForwardLinks(); // since some tok could be shadowed after exploring in the same decoding step } } - BuildHCLGMapFromHash(frame); // do it here to make it consistent - tb_+=timer.Elapsed(); } -void Lattice2BiglmFasterDecoder::BuildHCLGMapFromHash(int32 frame, bool append) { - if (!append) KALDI_ASSERT(toks_backfill_hclg_.size() > frame); - HashList &toks_shadowing_mod=toks_shadowing_[frame%2]; - StateHash *hclg_map = - new StateHash(); - hclg_map->reserve(toks_shadowing_mod.Size()); - for (const ElemShadow *e = toks_shadowing_mod.GetList(); e != NULL; e = e->tail) { - (*hclg_map)[e->key] = e->val; +void Lattice2BiglmFasterDecoder::BuildBackfillMap(int32 frame) { + KALDI_ASSERT(toks_backfill_pair_.size() == toks_backfill_hclg_.size()); + // We have already construct the map when we previous special case. 
+ if (toks_backfill_pair_.size() != frame || + active_toks_.size() == toks_backfill_pair_.size()) { return; } + if (active_toks_[frame].toks == NULL) { + KALDI_WARN << "BuildBackfillMap: no tokens active on frame " << frame; } - if (append) { + // Initialize + std::unordered_map *pair_map = + new std::unordered_map(); + std::unordered_map *hclg_map = + new std::unordered_map(); + + /* + for(Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + PairId cur_pair_id = ConstructPair(tok->hclg_state, tok->lm_state); + StateId cur_hclg_id = tok->hclg_state; + + if (pair_map->find(cur_pair_id) == pair_map->end()) { + (*pair_map)[cur_pair_id] = tok; + } + if (hclg_map->find(cur_hclg_id) != hclg_map->end()) { // already exist + if (tok->tot_cost < (*hclg_map)[cur_hclg_id]->tot_cost) { // better + (*hclg_map)[cur_hclg_id] = tok; + } + } else { // new + (*hclg_map)[cur_hclg_id] = tok; + } + } + + toks_backfill_pair_.push_back(pair_map); toks_backfill_hclg_.push_back(hclg_map); - } else { - std::swap(toks_backfill_hclg_[frame], hclg_map); - delete hclg_map; + */ + + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + const PairId cur_pair_id = ConstructPair(tok->hclg_state, tok->lm_state); + const StateId cur_hclg_id = tok->hclg_state; + + if (pair_map->find(cur_pair_id) == pair_map->end()) { + (*pair_map)[cur_pair_id] = tok; + } + // Check this token is an explored token or not + bool is_explored_tok = false; + for(ForwardLink *link = tok->links; link != NULL; link = link->next) { + if (link->ilabel != 0) { + is_explored_tok = true; + break; + } + } + if (is_explored_tok) { + KALDI_ASSERT(hclg_map->find(cur_hclg_id) == hclg_map->end()); + (*hclg_map)[cur_hclg_id] = tok; + } } - // sanity check - // for (auto i:(*hclg_map)) { - // KALDI_ASSERT(!i.second->shadowing_tok); - // //i.second->links is possible to be NULL since it is possible hasnt been pruned - // } + toks_backfill_pair_.push_back(pair_map); + toks_backfill_hclg_.push_back(hclg_map); } -void Lattice2BiglmFasterDecoder::InitDecoding() { - for (int i=0; i<2; i++) { - toks_backfill_pair_[i].clear(); - expand_current_frame_queue_[i] = std::queue(); + + +void Lattice2BiglmFasterDecoder::ProcessBetterExistingToken( + int32 cur_frame, + PairId new_pair_id, + BaseFloat new_tot_cost) { + if (cur_frame > active_toks_.size() - 1) { // have already expand to newest + // frame + return; } - expanding_ = false; - // clean up from last time. - DeleteElems(toks_.Clear()); - for (int i = 0; i < 2; i++) DeleteElemsShadow(toks_shadowing_[i]); - ClearActiveTokens(); + - cutoff_.Resize(1); - cutoff_.Data()[0] = std::numeric_limits::max(); + BuildBackfillMap(cur_frame); - // clean up private members - warned_noarc_ = false; - warned_ = false; - final_active_ = false; - final_costs_.clear(); - num_toks_ = 0; + KALDI_ASSERT((*toks_backfill_pair_[cur_frame])[new_pair_id]); - // At the beginning of an utterance, initialize. 
- toks_backfill_hclg_.resize(0); - PairId start_pair = ConstructPair(fst_.Start(), lm_diff_fst_->Start()); - active_toks_.resize(1); - Token *start_tok = new Token(0.0, 0.0, NULL, NULL, lm_diff_fst_->Start(), fst_.Start()); - active_toks_[0].toks = start_tok; - toks_.Insert(start_pair, start_tok); - toks_shadowing_[NumFramesDecoded()%2].Insert(fst_.Start(), start_tok); - num_toks_++; - propage_lm_num_=0; - propage_lm_expand_num_=0; - ProcessNonemitting(0); + Token *ori_tok = (*toks_backfill_pair_[cur_frame])[new_pair_id]; + BaseFloat diff = ori_tok->tot_cost - new_tot_cost; + if (diff <= 0) { + return; + } + + // Update the token + ori_tok->tot_cost = new_tot_cost; + BaseFloat new_extra_cost = ori_tok->extra_cost - diff; + ori_tok->extra_cost = new_extra_cost ? new_extra_cost : 0; + // The "ori_tok" could be a shadowed token or exploring token. + // If shadowing token is NULL, it means the token has already been expanded. + // So we only need to expand the change of tot_cost along each arc. + // If shadowing token isn't NULL, it means the token is shadowed. We will + // expand it. In this circumstance, the "links" maybe empty or not, which + // can be caused by ProcessNonemitting() [The token was the best token in + // certain HCLG state--"A" on this frame. But the processing of other token + // create a better token in this HCLG state--"A", so that both of the two + // tokens did the non-emit expand. The former one is shadowed token with + // link, the latter one was processed in "exploration" step.] + if (ori_tok->shadowing_tok == NULL) { // The token was expanded + // Expand the change of tot_cost along each arc. + for (ForwardLink *link = ori_tok->links; link != NULL; + link = link->next) { + Token *next_tok = link->next_tok; + int32 new_frame_index = link->ilabel ? cur_frame + 1 : cur_frame; + PairId next_pair_id = ConstructPair(next_tok->hclg_state, + next_tok->lm_state); + BaseFloat new_tot_cost = next_tok->tot_cost - diff; + ProcessBetterExistingToken(new_frame_index, next_pair_id, new_tot_cost); + } + // Check the token. It is the new best HCLG token or not + if (ori_tok->tot_cost < + (*toks_backfill_hclg_[cur_frame])[ori_tok->hclg_state]->tot_cost) { + // Get the Best HCLG token + Token* pre_best_hclg_tok = + (*toks_backfill_hclg_[cur_frame])[ori_tok->hclg_state]; + + // Update the hclg list + (*toks_backfill_hclg_[cur_frame])[ori_tok->hclg_state] = ori_tok; + + // Update the exploring list + if (cur_frame == active_toks_.size() - 1) { + HashList &toks_shadowing_mod = + toks_shadowing_[cur_frame % 2]; + StateId state = PairToState(new_pair_id); + ElemShadow *elem = toks_shadowing_mod.Find(state); + KALDI_ASSERT(elem); + KALDI_ASSERT(elem->val == pre_best_hclg_tok); + elem->val = ori_tok; + pre_best_hclg_tok->shadowing_tok = ori_tok; + } + } + } else { // The token hasn't been expanded + if (ori_tok->links == NULL) { + // If the token is the new best HCLG token, go on expanding. + if (ori_tok->tot_cost < + (*toks_backfill_hclg_[cur_frame])[ori_tok->hclg_state]->tot_cost) { + // Inside, the ori_tok->shadowing_tok will be set to NULL. + ProcessBetterHCLGToken(cur_frame, ori_tok); + } + } else { // An unexpanded token with links + // Handle the links + for (ForwardLink *link = ori_tok->links; link != NULL; + link = link->next) { + Token *next_tok = link->next_tok; + KALDI_ASSERT(link->ilabel == 0); // These links should be generated + // from ProcessNonemitting(). 
+        PairId next_pair_id = ConstructPair(next_tok->hclg_state,
+                                            next_tok->lm_state);
+        BaseFloat new_tot_cost = next_tok->tot_cost - diff;
+        ProcessBetterExistingToken(cur_frame, next_pair_id,
+                                   new_tot_cost);
+      }
+      // Check the token.
+      if (ori_tok->tot_cost <
+          (*toks_backfill_hclg_[cur_frame])[ori_tok->hclg_state]->tot_cost) {
+        // Inside, ori_tok->shadowing_tok will be set to NULL.
+        ProcessBetterHCLGToken(cur_frame, ori_tok);
+      }
+    }
+  }
 }
 
-void Lattice2BiglmFasterDecoder::BuildBackfillMap(int32 frame, int32 frame_stop_expand, bool clear) {
-  PairHash *pair_map = &GetBackfillMap(frame);
-  std::queue<QElem>& q = GetExpandQueue(frame);
-  if (clear)
-    pair_map->clear();
-  BaseFloat cur_cutoff = (frame+1 < cutoff_.Dim())? -cutoff_(frame+1) : std::numeric_limits<BaseFloat>::infinity();
+void Lattice2BiglmFasterDecoder::ProcessBetterHCLGToken(int32 cur_frame,
+                                                        Token *better_token) {
+  if (cur_frame > active_toks_.size() - 1) {  // have already expanded to the
+                                              // newest frame
+    return;
+  }
+
+  BuildBackfillMap(cur_frame);
+  BuildBackfillMap(cur_frame+1);
+
+  // Get the previous best one.
+  Token *pre_best_hclg_tok =
+      (*toks_backfill_hclg_[cur_frame])[better_token->hclg_state];
 
-  for(Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
-    if (tok->tot_cost > cur_cutoff) {
-      tok->shadowing_tok=NULL;
-      tok->in_queue=false;
+  if (better_token == pre_best_hclg_tok) {
+    return;
+  }
+  // Iterate over all links of the previous best token.
+  for (ForwardLink *link = pre_best_hclg_tok->links; link != NULL;
+       link = link->next) {
+    Token *next_tok = link->next_tok;
+    Arc arc(link->ilabel, link->olabel, link->graph_cost_ori, 0);
+    StateId new_hclg_state = next_tok->hclg_state;
+    StateId new_lm_state = PropagateLm(better_token->lm_state, &arc);  // may affect "arc.weight".
+    PairId new_pair = ConstructPair(new_hclg_state, new_lm_state);
+    BaseFloat ac_cost = link->acoustic_cost,
+              graph_cost = arc.weight.Value(),
+              cur_cost = better_token->tot_cost,
+              tot_cost = cur_cost + ac_cost + graph_cost;
+    // Use the extra_cost of pre_best_hclg_tok's destination to prune the new token.
+    const BaseFloat ref_extra_cost = next_tok->extra_cost;
+    BaseFloat extra_cost = ref_extra_cost +
+                           (tot_cost - next_tok->tot_cost);
+    if (extra_cost > config_.beam) {  // skip this link
       continue;
     }
-    PairId cur_pair_id = ConstructPair(tok->hclg_state, tok->lm_state);
-
-    bool ok = pair_map->insert({cur_pair_id, tok}).second;
-    if (frame <= frame_stop_expand) { // need to expand
-      if (ok) { // without this tok before
-        if (tok->shadowing_tok) {
-          q.push(QElem(tok, false));
-          tok->in_queue=true;
-        } else tok->in_queue = false;
-      } else KALDI_ASSERT(tok->in_queue);  // tok has been pushed by ExpandShadowTokens
+    // Prepare to store a new token in the current / next frame.
+    int32 new_frame_index = link->ilabel ? cur_frame+1 : cur_frame;
+    Token *&toks = active_toks_[new_frame_index].toks;
+    assert(toks);
+
+    bool exist_flag = false;
+    if (toks_backfill_pair_[new_frame_index]->find(new_pair) !=
+        toks_backfill_pair_[new_frame_index]->end()) {
+      exist_flag = true;
+    }
+    // Special case: an arc that we expand in backfill reaches an existing
+    // state, but it gives that state a better forward cost than before.
+    if (exist_flag && tot_cost <
+        (*toks_backfill_pair_[new_frame_index])[new_pair]->tot_cost) {
+      // Update the destination token.
+      ProcessBetterExistingToken(new_frame_index, new_pair, tot_cost);
     }
+    Token* new_tok;
+    if (!exist_flag) {  // A new token.
+      // Construct the new token.
+      new_tok = new Token(tot_cost, extra_cost, NULL, toks, new_lm_state,
+                          new_hclg_state);
+      new_tok->shadowing_tok =
+          (*toks_backfill_hclg_[new_frame_index])[new_hclg_state];
+      toks = new_tok;
+      num_toks_++;
+      // Add the new token to the "backfill" map.
+      (*toks_backfill_pair_[new_frame_index])[new_pair] = new_tok;
+    } else {  // An existing token
+      new_tok = (*toks_backfill_pair_[new_frame_index])[new_pair];
+    }
+    // Create the lattice arc.
+    better_token->links = new ForwardLink(new_tok, arc.ilabel, arc.olabel,
+                                          graph_cost, ac_cost, better_token->links,
+                                          link->graph_cost_ori);
+
+    // Special case: a previously unseen state was created that has a higher
+    // probability than an existing copy of the same HCLG.fst state.
+    if (tot_cost <
+        (*toks_backfill_hclg_[new_frame_index])[new_hclg_state]->tot_cost) {
+      ProcessBetterHCLGToken(new_frame_index, new_tok);
+    } else {
+      if (!exist_flag) {
+        new_tok->shadowing_tok =
+            (*toks_backfill_hclg_[new_frame_index])[new_hclg_state];
+      }
+    }
+  }  // end of for loop
+  // As the expanding process may change the best HCLG token, we
+  // reacquire it.
+  pre_best_hclg_tok =
+      (*toks_backfill_hclg_[cur_frame])[better_token->hclg_state];
+
+  // Update the exploring list.
+  if (cur_frame == active_toks_.size() - 1) {
+    HashList<StateId, Token*> &toks_shadowing_mod =
+        toks_shadowing_[cur_frame % 2];
+    StateId state = better_token->hclg_state;
+    ElemShadow *elem = toks_shadowing_mod.Find(state);
+    KALDI_ASSERT(elem);
+    KALDI_ASSERT(elem->val == pre_best_hclg_tok);
+    elem->val = better_token;
+    pre_best_hclg_tok->shadowing_tok = better_token;
   }
+
+  // Update
+  (*toks_backfill_hclg_[cur_frame])[better_token->hclg_state] = better_token;
+  better_token->shadowing_tok = NULL;  // already expanded
 }
+
+void Lattice2BiglmFasterDecoder::UpdateBackwardCost(int32 cur_frame,
+                                                    BaseFloat delta) {
+  // Set the backward-cost of the current frame's tokens to zero.
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next) {
+    tok->backward_cost = 0;
+  }
+
+  for (int32 frame = cur_frame-1;
+       frame >= 0 && frame > cur_frame - config_.prune_interval; frame--) {
+    if (active_toks_[frame].toks == NULL) {  // empty list; should not happen.
+      KALDI_WARN << "No tokens alive when computing backward cost";
+    }
+
+    // We have to iterate until there is no more change, because the links
+    // are not guaranteed to be in topological order.
+    bool changed = true;  // difference new minus old backward cost >= delta ?
+    while (changed) {
+      changed = false;
+      for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) {
+        ForwardLink *link;
+        // will recompute tok_backward_cost for tok.
+        BaseFloat tok_backward_cost = std::numeric_limits<BaseFloat>::infinity();
+        // tok_backward_cost is the best (min) link_backward_cost of the outgoing links.
+        for (link = tok->links; link != NULL; link = link->next) {
+          Token *next_tok = link->next_tok;
+          BaseFloat link_backward_cost =
+              std::numeric_limits<BaseFloat>::infinity();
+          if (next_tok->shadowing_tok) {
+            link_backward_cost = next_tok->shadowing_tok->backward_cost +
+                                 link->acoustic_cost + link->graph_cost;
+          } else {
+            link_backward_cost = next_tok->backward_cost + link->acoustic_cost +
+                                 link->graph_cost;
+          }
+          KALDI_ASSERT(link_backward_cost == link_backward_cost);  // check for NaN
+          tok_backward_cost = std::min(tok_backward_cost, link_backward_cost);
+        }
+        if (fabs(tok_backward_cost - tok->backward_cost) > delta)
+          changed = true;  // difference new minus old is bigger than delta
+        tok->backward_cost = tok_backward_cost;
+      }  // for all Tokens on active_toks_[frame]
+    }
+  }
+  // Set the backward-cost of the current frame's tokens back to infinity.
+  for (Token* tok = active_toks_[cur_frame].toks; tok != NULL; tok = tok->next) {
+    tok->backward_cost = std::numeric_limits<BaseFloat>::infinity();
+  }
+}
+
 }
diff --git a/src/decoder/lattice2-biglm-faster-decoder.h b/src/decoder/lattice2-biglm-faster-decoder.h
index 26bc7388c07..c049764d66f 100644
--- a/src/decoder/lattice2-biglm-faster-decoder.h
+++ b/src/decoder/lattice2-biglm-faster-decoder.h
@@ -1,6 +1,7 @@
 // decoder/lattice2-biglm-faster-decoder.h
 
-// Copyright 2018  Hang Lyu  Zhehuai Chen
+// Copyright 2018  Zhehuai Chen
+//                 Hang Lyu
 
 // See ../../COPYING for clarification regarding multiple authors
 //
@@ -28,72 +29,11 @@
 #include "fstext/fstext-lib.h"
 #include "lat/kaldi-lattice.h"
 #include "decoder/lattice-faster-decoder.h"  // for options.
-#include "base/timer.h"
 
 namespace kaldi {
 
-struct Lattice2BiglmFasterDecoderConfig{
-  BaseFloat beam;
-  int32 max_active;
-  int32 min_active;
-  BaseFloat lattice_beam;
-  int32 prune_interval;
-  bool determinize_lattice;  // not inspected by this class... used in
-                             // command-line program.
-  BaseFloat beam_delta;  // has nothing to do with beam_ratio
-  BaseFloat hash_ratio;
-  BaseFloat expand_beam;
-  BaseFloat prune_scale;  // Note: we don't make this configurable on the command line,
-                          // it's not a very important parameter.  It affects the
-                          // algorithm that prunes the tokens as we go.
-  // Most of the options inside det_opts are not actually queried by the
-  // LatticeFasterDecoder class itself, but by the code that calls it, for
-  // example in the function DecodeUtteranceLatticeFaster.
-  fst::DeterminizeLatticePhonePrunedOptions det_opts;
-  int better_hclg;
-  int explore_interval;
-
-  Lattice2BiglmFasterDecoderConfig(): beam(16.0),
-                                      max_active(std::numeric_limits<int32>::max()),
-                                      min_active(200),
-                                      lattice_beam(10.0),
-                                      prune_interval(25),
-                                      determinize_lattice(true),
-                                      beam_delta(0.5),
-                                      hash_ratio(2.0),
-                                      expand_beam(16.0),
-                                      prune_scale(0.1),
-                                      better_hclg(false), explore_interval(0) { }
-  void Register(OptionsItf *opts) {
-    det_opts.Register(opts);
-    opts->Register("beam", &beam, "Decoding beam.  Larger->slower, more accurate.");
-    opts->Register("max-active", &max_active, "Decoder max active states.  Larger->slower; "
-                   "more accurate");
-    opts->Register("min-active", &min_active, "Decoder minimum #active states.");
-    opts->Register("lattice-beam", &lattice_beam, "Lattice generation beam.  Larger->slower, "
-                   "and deeper lattices");
-    opts->Register("prune-interval", &prune_interval, "Interval (in frames) at "
-                   "which to prune tokens");
-    opts->Register("determinize-lattice", &determinize_lattice, "If true, "
-                   "determinize the lattice (lattice-determinization, keeping only "
-                   "best pdf-sequence for each word-sequence).");
-    opts->Register("beam-delta", &beam_delta, "Increment used in decoding-- this "
-                   "parameter is obscure and relates to a speedup in the way the "
-                   "max-active constraint is applied.  Larger is more accurate.");
-    opts->Register("hash-ratio", &hash_ratio, "Setting used in decoder to "
-                   "control hash behavior");
-    opts->Register("expand-beam", &expand_beam, "Expanding beam.");
-    opts->Register("better-hclg", &better_hclg, "Expanding better HCLG states.");
-    opts->Register("explore-interval", &explore_interval, "the interval between explore and expand.");
-  }
-  void Check() const {
-    KALDI_ASSERT(beam > 0.0 && max_active > 1 && lattice_beam > 0.0
-                 && prune_interval > 0 && beam_delta > 0.0 && hash_ratio >= 1.0
-                 && prune_scale > 0.0 && prune_scale < 1.0);
-  }
-};
-
-
+// The options are the same as for lattice-faster-decoder.h for now.
+typedef LatticeFasterDecoderConfig Lattice2BiglmFasterDecoderConfig;
 
 /** This is as LatticeFasterDecoder, but does online composition between
     HCLG and the "difference language model", which is a deterministic
@@ -131,9 +71,9 @@ class Lattice2BiglmFasterDecoder {
     ClearActiveTokens();
     // Clean up backfill map
     for (int32 frame = NumFramesDecoded(); frame >= 0; frame--) {
+      delete toks_backfill_pair_[frame];
       delete toks_backfill_hclg_[frame];
     }
-    KALDI_VLOG(1) << "time: " << expand_time_ << " " << propage_time_<< " " << ta_ << " " << tb_;
   }
 
   inline int32 NumFramesDecoded() const { return active_toks_.size() - 1; }
@@ -167,7 +107,7 @@ class Lattice2BiglmFasterDecoder {
   // Furthermore, the expanding will be related to current frame and next frame.
   // For judging the token has better cost or reaching existing token, we build
   // the backfill maps.
-  void ExpandShadowTokens(int32 frame, int32 frame_stop_expand, DecodableInterface *decodable, bool first=false);
+  void ExpandShadowTokens(int32 frame);
 
   /// says whether a final-state was active on the last frame.  If it was not, the
   /// lattice (or traceback) will end with states that are not final-states.
@@ -265,21 +205,20 @@ class Lattice2BiglmFasterDecoder {
    // will be expanded. In another word, if we prune the
    // lattice on each frame rather than prune it periodly,
    // we only expand the survived tokens after pruning.
-    bool in_queue;
 
     inline Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links,
                  Token *next, StateId lm_state, StateId hclg_state):
         tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
         shadowing_tok(NULL), lm_state(lm_state), hclg_state(hclg_state),
-        backward_cost(std::numeric_limits<BaseFloat>::infinity()), in_queue(0) {}
+        backward_cost(std::numeric_limits<BaseFloat>::infinity()) {}
 
     inline Token(BaseFloat tot_cost, BaseFloat extra_cost, ForwardLink *links,
                  Token *next, StateId lm_state, StateId hclg_state,
                  BaseFloat backward_cost):
         tot_cost(tot_cost), extra_cost(extra_cost), links(links), next(next),
         shadowing_tok(NULL), lm_state(lm_state),
-        hclg_state(hclg_state), backward_cost(backward_cost), in_queue(0) {}
+        hclg_state(hclg_state), backward_cost(backward_cost) {}
 
     inline void DeleteForwardLinks() {
@@ -291,12 +230,6 @@ class Lattice2BiglmFasterDecoder {
       }
       links = NULL;
     }
-    inline bool operator < (const Token &other) const {
-      if (tot_cost == other.tot_cost)  // this is important to garrenttee a single shadowing token
-        return lm_state < other.lm_state;
-      else return tot_cost < other.tot_cost;
-    }
-    inline bool operator > (const Token &other) const { return other < (*this); }
   };
 
   // head and tail of per-frame list of Tokens (list is in topological order),
@@ -318,8 +251,6 @@
     if (new_sz > toks_.Size()) {
       toks_.SetSize(new_sz);
     }
-    HashList<StateId, Token*> &h = toks_shadowing_[NumFramesDecoded()%2];
-    if (new_sz > h.Size()) h.SetSize(new_sz);
   }
 
   // FindOrAddToken either locates a token in hash of toks_,
@@ -371,7 +302,7 @@
   // it's called by PruneActiveTokens
   // all links, that have link_extra_cost > lattice_beam are pruned
   void PruneForwardLinks(int32 frame, bool *extra_costs_changed,
-                         bool *links_pruned, BaseFloat delta, bool is_expand);
+                         bool *links_pruned, BaseFloat delta);
 
   // PruneForwardLinksFinal is a version of PruneForwardLinks that we call
  // on the final frame.  If there are final tokens active, it uses
@@ -382,7 +313,7 @@
  // [we don't do this in PruneForwardLinks because it would give us
  // a problem with dangling pointers].
  // It's called by PruneActiveTokens if any forward links have been pruned
-  void PruneTokensForFrame(int32 frame, bool is_expand);
+  void PruneTokensForFrame(int32 frame);
 
  // Go backwards through still-alive tokens, pruning them.  note: cur_frame is
  // where hash toks_ are (so we do not want to mess with it because these tokens
@@ -395,7 +326,7 @@
 
  // Version of PruneActiveTokens that we call on the final frame.
  // Takes into account the final-prob of tokens.
-  void PruneActiveTokensFinal(int32 cur_frame, bool is_expand=false);
+  void PruneActiveTokensFinal(int32 cur_frame);
 
  /// Gets the weight cutoff.  Also counts the active tokens.
  BaseFloat GetCutoff(Elem *list_head, size_t *tok_count,
@@ -407,12 +338,8 @@
     if (arc->olabel == 0) {
       return lm_state;  // no change in LM state if no word crossed.
     } else {  // Propagate in the LM-diff FST.
-      Timer timer;
-      propage_lm_num_++;
-      if (expanding_) propage_lm_expand_num_++;
       Arc lm_arc;
       bool ans = lm_diff_fst_->GetArc(lm_state, arc->olabel, &lm_arc);
-      propage_time_+=timer.Elapsed();
       if (!ans) {  // this case is unexpected for statistical LMs.
         if (!warned_noarc_) {
           warned_noarc_ = true;
@@ -454,20 +381,8 @@
  // The following variables are used to check the existing tokens and best
  // token in certain frame. It will build in function ExpandShadowTokens()
  // Each element in the vector corresponds to a frame(t).
-  // TODO: add comments: we only update toks_shadowing_ but not toks_backfill_hclg_
-  typedef std::unordered_map<StateId, Token*, std::hash<StateId>, std::equal_to<StateId>,
-      fst::PoolAllocator<std::pair<const StateId, Token*> > > StateHash;
-  typedef std::unordered_map<PairId, Token*, std::hash<PairId>, std::equal_to<PairId>,
-      fst::PoolAllocator<std::pair<const PairId, Token*> > > PairHash;
-  PairHash toks_backfill_pair_[2];
-  std::vector<StateHash*> toks_backfill_hclg_;
-  typedef std::pair<Token*, bool> QElem;
-  std::queue<QElem> expand_current_frame_queue_[2];
-  std::queue<QElem>& GetExpandQueue(int32 frame) { return expand_current_frame_queue_[frame%2]; }
-  PairHash& GetBackfillMap(int32 frame) { return toks_backfill_pair_[frame%2]; }
-  void InitDecoding();
+  std::vector<std::unordered_map<PairId, Token*>* > toks_backfill_pair_;
+  std::vector<std::unordered_map<StateId, Token*>* > toks_backfill_hclg_;
 
  // temp variable used to process special case. The pair is (t, state_id).
  // As we want to process the token which has smaller t index at first,
@@ -484,11 +399,6 @@
  std::vector<TokenList> active_toks_;  // Lists of tokens, indexed by
-  int32 ToksNum(int32 f) {
-    int32 c=0;
-    for (Token *t=active_toks_[f].toks; t; t=t->next) c++;
-    return c;
-  }
  // frame (members of TokenList are toks, must_prune_forward_links,
  // must_prune_tokens).
 
  std::vector<PairId> queue_;  // temp variable used in ProcessNonemitting,
@@ -547,20 +457,26 @@
  // Actually, we only build the two maps for each frame once. Otherwise, in
  // ExpandShadowTokens(), it will be increased. In PruneTokenForFrame(), it
  // will be decreased.
-  void BuildBackfillMap(int32 frame, int32 frame_stop_expand, bool clear=false);
-  void BuildHCLGMapFromHash(int32 frame, bool append=true);
-  Token *ExpandShadowTokensSub(StateId ilabel,
-      StateId new_hclg_state, StateId new_lm_state, int32 frame,
-      int32 new_frame_index, BaseFloat tot_cost, BaseFloat extra_cost, BaseFloat backward_cost,
-      bool is_last);
+  void BuildBackfillMap(int32 frame);
+
+  // A recursive function.  It is called when LM histories merge and a
+  // previously un-promising path becomes better than an existing token.
+  // Before further exploration, propagate the change in cost forward through
+  // the lattice until it reaches the current frame, so that we can decode
+  // with up-to-date alphas.
+  void ProcessBetterExistingToken(int32 cur_frame, PairId new_pair_id,
+                                  BaseFloat new_tot_cost);
+
+  // A recursive function.  Propagate this state and its successors until the
+  // current frame.
+  void ProcessBetterHCLGToken(int32 cur_frame, Token *better_token);
+
+  // Update the backward cost of each token.  Assume the current frame is the
+  // fake final frame.  Iterate from frame-1 down to 0.  For each token,
+  //   tok->backward_cost = min over links of (next_tok->backward_cost +
+  //                                           link->graph_cost + link->acoustic_cost)
+  void UpdateBackwardCost(int32 cur_frame, BaseFloat delta);
 
  Vector<BaseFloat> cutoff_;
-  uint64 propage_lm_num_;
-  uint64 propage_lm_expand_num_;
-  bool expanding_;
-  double expand_time_;
-  double propage_time_;
-  double ta_, tb_;
 };
 
 }  // end namespace kaldi.
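
Editor's note (illustration only, not part of the patch): BuildBackfillMap() above maintains, per frame, a map from (HCLG state, LM state) pairs to tokens plus a map from HCLG states to their explored token. The self-contained sketch below shows the same bookkeeping with simplified stand-in types; SimpleToken, the map-building loop, and the ConstructPair() packing shown here are assumptions for illustration, not the decoder's actual definitions.

// backfill_sketch.cc -- simplified illustration of the per-frame backfill maps.
// SimpleToken and the helpers below are stand-ins, not Kaldi classes.
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

typedef int32_t StateId;
typedef uint64_t PairId;

struct SimpleToken {
  StateId hclg_state;
  StateId lm_state;
  float tot_cost;
  bool explored;  // has at least one emitting out-arc already
};

// Pack (HCLG state, LM state) into one 64-bit key, as the decoder's
// ConstructPair() is assumed to do.
static PairId ConstructPair(StateId hclg, StateId lm) {
  return static_cast<PairId>(hclg) + (static_cast<PairId>(lm) << 32);
}

int main() {
  std::vector<SimpleToken> frame_toks = {
      {10, 3, 12.5f, true}, {10, 7, 11.0f, false}, {42, 3, 9.0f, true}};

  std::unordered_map<PairId, SimpleToken*> pair_map;   // one token per (hclg, lm) pair
  std::unordered_map<StateId, SimpleToken*> hclg_map;  // best explored token per hclg state

  for (SimpleToken &tok : frame_toks) {
    pair_map.emplace(ConstructPair(tok.hclg_state, tok.lm_state), &tok);
    if (tok.explored) {
      auto it = hclg_map.find(tok.hclg_state);
      if (it == hclg_map.end() || tok.tot_cost < it->second->tot_cost)
        hclg_map[tok.hclg_state] = &tok;
    }
  }
  std::cout << "pairs: " << pair_map.size()
            << ", explored HCLG states: " << hclg_map.size() << std::endl;
  return 0;
}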
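
Editor's note (illustration only, not part of the patch): ProcessBetterExistingToken() propagates an improved forward cost through links that already exist; when a token gets cheaper, each successor reached through its links can be relaxed in turn. A minimal sketch of that idea on a plain graph follows. Node, Link, and ProcessBetterCost() are hypothetical stand-ins, and this sketch recomputes successor costs from the link costs rather than subtracting the improvement as the patch does.

// better_cost_sketch.cc -- propagating an improved forward cost through
// existing links, in the spirit of ProcessBetterExistingToken().
#include <iostream>
#include <vector>

struct Node;

struct Link {
  Node *next;
  float graph_cost;
  float acoustic_cost;
};

struct Node {
  float tot_cost;          // best forward (alpha) cost found so far
  std::vector<Link> links;
};

// If 'new_tot_cost' improves on node->tot_cost, take it and relax every
// successor reachable through the node's existing links, recursively.
void ProcessBetterCost(Node *node, float new_tot_cost) {
  float diff = node->tot_cost - new_tot_cost;
  if (diff <= 0) return;   // not actually better; nothing to do
  node->tot_cost = new_tot_cost;
  for (const Link &link : node->links) {
    float successor_cost = node->tot_cost + link.graph_cost + link.acoustic_cost;
    ProcessBetterCost(link.next, successor_cost);
  }
}

int main() {
  Node c{20.0f, {}};
  Node b{12.0f, {{&c, 1.0f, 7.0f}}};
  Node a{10.0f, {{&b, 0.5f, 1.5f}}};
  ProcessBetterCost(&a, 8.0f);  // a improves by 2; b and c follow suit
  std::cout << a.tot_cost << " " << b.tot_cost << " " << c.tot_cost << std::endl;
  return 0;
}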
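
Editor's note (illustration only, not part of the patch): UpdateBackwardCost() gives the newest frame's tokens a backward cost of zero and then relaxes earlier frames until no token changes by more than delta, since links within a frame are not guaranteed to be topologically ordered. A toy version of that fixed-point loop, with stand-in Tok/Arc types, is sketched below.

// backward_cost_sketch.cc -- fixed-point relaxation of backward (beta) costs,
// mirroring the structure of UpdateBackwardCost(); types are stand-ins.
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <vector>

struct Tok;

struct Arc {
  Tok *next;
  float cost;  // graph_cost + acoustic_cost of the link
};

struct Tok {
  float backward_cost;
  std::vector<Arc> arcs;
};

// frames[t] holds the tokens active on frame t; the last frame acts as the
// "fake final" frame whose tokens get backward cost 0.
void UpdateBackwardCosts(std::vector<std::vector<Tok*> > *frames, float delta) {
  const float inf = std::numeric_limits<float>::infinity();
  for (Tok *tok : frames->back()) tok->backward_cost = 0.0f;
  for (int t = static_cast<int>(frames->size()) - 2; t >= 0; t--) {
    bool changed = true;
    while (changed) {       // iterate to a fixed point: links may not be
      changed = false;      // in topological order within a frame
      for (Tok *tok : (*frames)[t]) {
        float best = inf;
        for (const Arc &arc : tok->arcs)
          best = std::min(best, arc.next->backward_cost + arc.cost);
        if (std::fabs(best - tok->backward_cost) > delta) changed = true;
        tok->backward_cost = best;
      }
    }
  }
}

int main() {
  Tok final_tok{std::numeric_limits<float>::infinity(), {}};
  Tok mid{std::numeric_limits<float>::infinity(), {{&final_tok, 2.0f}}};
  Tok start{std::numeric_limits<float>::infinity(), {{&mid, 1.0f}}};
  std::vector<std::vector<Tok*> > frames = {{&start}, {&mid}, {&final_tok}};
  UpdateBackwardCosts(&frames, 0.01f);
  std::cout << start.backward_cost << std::endl;  // prints 3
  return 0;
}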