diff --git a/src/decoder/lattice-biglm-faster-decoder.h b/src/decoder/lattice-biglm-faster-decoder.h index 15abe4c0482..dc74d9a016a 100644 --- a/src/decoder/lattice-biglm-faster-decoder.h +++ b/src/decoder/lattice-biglm-faster-decoder.h @@ -29,6 +29,7 @@ #include "fstext/fstext-lib.h" #include "lat/kaldi-lattice.h" #include "decoder/lattice-faster-decoder.h" // for options. +#include "base/timer.h" namespace kaldi { @@ -71,6 +72,7 @@ class LatticeBiglmFasterDecoder { ~LatticeBiglmFasterDecoder() { DeleteElems(toks_.Clear()); ClearActiveTokens(); + KALDI_VLOG(1) << "time: " << expand_time_ << " " << propage_time_<< " " << ta_ << " " << tb_; } // Get Cutoff @@ -686,9 +688,11 @@ class LatticeBiglmFasterDecoder { if (arc->olabel == 0) { return lm_state; // no change in LM state if no word crossed. } else { // Propagate in the LM-diff FST. + Timer timer; propage_lm_num_++; Arc lm_arc; bool ans = lm_diff_fst_->GetArc(lm_state, arc->olabel, &lm_arc); + propage_time_+=timer.Elapsed(); if (!ans) { // this case is unexpected for statistical LMs. if (!warned_noarc_) { warned_noarc_ = true; @@ -708,6 +712,7 @@ class LatticeBiglmFasterDecoder { void ProcessEmitting(DecodableInterface *decodable, int32 frame) { // Processes emitting arcs for one frame. Propagates from prev_toks_ to cur_toks_. + Timer timer; Elem *last_toks = toks_.Clear(); // swapping prev_toks_ / cur_toks_ Elem *best_elem = NULL; BaseFloat adaptive_beam; @@ -783,9 +788,11 @@ class LatticeBiglmFasterDecoder { e_tail = e->tail; toks_.Delete(e); // delete Elem } + ta_+=timer.Elapsed(); } void ProcessNonemitting(int32 frame) { + Timer timer; // note: "frame" is the same as emitting states just processed. // Processes nonemitting arcs for one frame. Propagates within toks_. 
@@ -852,6 +859,7 @@ class LatticeBiglmFasterDecoder { } } // for all arcs } // while queue not empty + tb_+=timer.Elapsed(); } @@ -909,6 +917,9 @@ class LatticeBiglmFasterDecoder { KALDI_ASSERT(num_toks_ == 0); } uint64 propage_lm_num_; + double expand_time_; + double propage_time_; + double ta_, tb_; }; } // end namespace kaldi. diff --git a/src/decoder/lattice2-biglm-faster-decoder.cc b/src/decoder/lattice2-biglm-faster-decoder.cc index 9cd02b8c6d6..c2ff82865f0 100644 --- a/src/decoder/lattice2-biglm-faster-decoder.cc +++ b/src/decoder/lattice2-biglm-faster-decoder.cc @@ -32,8 +32,8 @@ Lattice2BiglmFasterDecoder::Lattice2BiglmFasterDecoder( lm_diff_fst->Start() != fst::kNoStateId); toks_.SetSize(1000); // just so on the first frame we do something reasonable. for (int i = 0; i < 2; i++) toks_shadowing_[i].SetSize(1000); // just so on the first frame we do something reasonable. - toks_backfill_hclg_.resize(0); - } + ClearHCLGMap(); +} bool Lattice2BiglmFasterDecoder::Decode(DecodableInterface *decodable) { @@ -53,15 +53,16 @@ bool Lattice2BiglmFasterDecoder::Decode(DecodableInterface *decodable) { if (frame % config_.prune_interval == 0) { PruneActiveTokens(frame, config_.lattice_beam * 0.1); // use larger delta. 
} - int32 t = frame-config_.prune_interval-config_.explore_interval; - if (t >= 0 && (frame-config_.explore_interval) % config_.prune_interval == 0) { - KALDI_ASSERT(t==last_expand_frame); - for (; t<=frame; t++) - ExpandShadowTokens(t, frame-config_.explore_interval-1, decodable, t==last_expand_frame); - last_expand_frame=frame-config_.explore_interval; + int32 t = frame - config_.prune_interval - config_.explore_interval; + if (t >= 0 && + (frame - config_.explore_interval) % config_.prune_interval == 0) { + KALDI_ASSERT(t == last_expand_frame); + for (; t <= frame; t++) { + ExpandShadowTokens(t, frame - config_.explore_interval - 1, decodable, + t == last_expand_frame); + } + last_expand_frame = frame - config_.explore_interval; } - - // We could add another config option to decide the gap between state passing // and lm passing. } @@ -95,33 +96,39 @@ bool Lattice2BiglmFasterDecoder::Decode(DecodableInterface *decodable, -void Lattice2BiglmFasterDecoder::ExpandShadowTokens(int32 cur_frame, int32 frame_stop_expand, DecodableInterface *decodable, bool first) { +void Lattice2BiglmFasterDecoder::ExpandShadowTokens(int32 cur_frame, + int32 frame_stop_expand, DecodableInterface *decodable, bool first) { Timer timer; - expanding_=true; - bool is_last = cur_frame <= frame_stop_expand; // the last time we do expand in this frame + expanding_ = true; + + bool is_last = cur_frame <= frame_stop_expand; // the last time we do expand + // in this frame KALDI_ASSERT(cur_frame >= 0); auto& cur_q = GetExpandQueue(cur_frame); auto& cur_h = GetBackfillMap(cur_frame); - if (cur_frame > frame_stop_expand && !cur_q.size()) { + + if (cur_frame > frame_stop_expand && cur_q.empty()) { expanding_=false; return; } + // If cur_frame is the first frame of this segment, build the pair map for it. + // Otherwise, it has already been built by the previous frame. 
if (first) BuildBackfillMap(cur_frame, frame_stop_expand, first); - if ( (cur_frame + 1) < active_toks_.size()) { + if ((cur_frame + 1) < active_toks_.size()) { BuildBackfillMap(cur_frame+1, frame_stop_expand, true); } while (!cur_q.empty()) { - auto q_elem= cur_q.front(); + auto q_elem = cur_q.front(); cur_q.pop(); Token* tok = q_elem.first; - tok->in_queue=false; + tok->in_queue = false; bool cur_better_hclg = q_elem.second; int32 frame = cur_frame; BaseFloat cur_cutoff = (frame+1 < cutoff_.Dim())? -cutoff_(frame+1) : std::numeric_limits::infinity(); + cutoff_(frame+1) : std::numeric_limits::infinity(); if (tok->tot_cost > cur_cutoff) { tok->shadowing_tok = NULL; // already expand @@ -131,33 +138,23 @@ cutoff_(frame+1) : std::numeric_limits::infinity(); ForwardLink *link=NULL, *links_to_clear=NULL; if (tok->shadowing_tok == NULL) { // if we need to update a shadowing token itself - link=tok->links; + link = tok->links; tok->links=NULL; // we firstly un-hook it from shadowing token - links_to_clear=link; // we will reconstruct it and delete the original one later + links_to_clear=link; // we will reconstruct it and delete the original + // one later } else { // otherwise, we are updating a shadowed token // we obtain template links from shadowing token // sanity check: - // KALDI_ASSERT(toks_backfill_hclg_[frame]->find(tok->hclg_state)->second==tok->shadowing_tok); + // KALDI_ASSERT(toks_backfill_hclg_[frame]->find( + // tok->hclg_state)->second == tok->shadowing_tok); KALDI_ASSERT(!tok->links); Token* shadowing_tok = tok; - while (shadowing_tok->shadowing_tok && !shadowing_tok->links) shadowing_tok = shadowing_tok->shadowing_tok; - // Update toks_shadowing_mod for better_hclg here - // Notice that we only update if it reaches NumFramesDecoded(), since it will affect explore - if (frame == NumFramesDecoded() && *tok > *shadowing_tok) { - HashList &toks_shadowing_mod=toks_shadowing_[frame%2]; - ElemShadow *elem = toks_shadowing_mod.Find(tok->hclg_state); - if 
(elem) { - if (*tok < *elem->val) { - elem->val = tok; - // sanity check - // KALDI_ASSERT(!(*toks_backfill_hclg_[frame]).find(tok->hclg_state)->second->shadowing_tok); - } - } else // from better_hclg - toks_shadowing_mod.Insert(tok->hclg_state, tok); - } + while (shadowing_tok->shadowing_tok && !shadowing_tok->links) + shadowing_tok = shadowing_tok->shadowing_tok; link = shadowing_tok->links; - if (!link) { + + if (!link) { // link == NULL // for the end of decoding, we need to expand all if (is_last) tok->shadowing_tok = NULL; @@ -167,13 +164,14 @@ cutoff_(frame+1) : std::numeric_limits::infinity(); if (*iter.second > *tok) // better_hclg tok->shadowing_tok = NULL; else - KALDI_ASSERT(tok->shadowing_tok == iter.second); + tok->shadowing_tok = iter.second; // TODO: check this one + //KALDI_ASSERT(tok->shadowing_tok == iter.second); } // for normal shadowed token is_last==false, we process it later continue; } } - if (cur_better_hclg && config_.better_hclg==2) { + if (cur_better_hclg && config_.better_hclg >= 2) { for (fst::ArcIterator > aiter(fst_, tok->hclg_state); !aiter.Done(); aiter.Next()) { @@ -197,7 +195,10 @@ cutoff_(frame+1) : std::numeric_limits::infinity(); if (new_frame_index+1 < cutoff_.Dim() && tot_cost > cutoff_(new_frame_index+1)) continue; if (extra_cost > config_.lattice_beam) continue; - Token* new_tok = ExpandShadowTokensSub(ilabel, new_hclg_state, new_lm_state, frame, new_frame_index, tot_cost, extra_cost, backward_cost, is_last); + Token* new_tok = ExpandShadowTokensSub(ilabel, new_hclg_state, + new_lm_state, frame, new_frame_index, tot_cost, extra_cost, + backward_cost, is_last); + if (!new_tok) continue; // create lattice arc tok->links = new ForwardLink(new_tok, arc.ilabel, arc.olabel, graph_cost, ac_cost, tok->links, @@ -217,15 +218,21 @@ cutoff_(frame+1) : std::numeric_limits::infinity(); // However, the way to deal with them is similar. 
for (; link != NULL; link = link->next) { Token *next_tok = link->next_tok; - while (next_tok->shadowing_tok && !next_tok->links) next_tok=next_tok->shadowing_tok; + // Find the exploration token for this hclg state. + while (next_tok->shadowing_tok && !next_tok->links) + next_tok=next_tok->shadowing_tok; + StateId ilabel = link->ilabel; - int32 new_frame_index = ilabel ? frame+1 : frame; - if (new_frame_indexlinks) continue; // this link should be pruned - + int32 new_frame_index = ilabel ? frame+1 : frame; + + // this link should be pruned + if (new_frame_index < NumFramesDecoded() && !next_tok->links) continue; + Arc arc(ilabel, link->olabel, link->graph_cost_ori, 0); BaseFloat graph_cost_ori = link->graph_cost_ori; // TODO StateId new_hclg_state = next_tok->hclg_state; - StateId new_lm_state = PropagateLm(tok->lm_state, &arc); // may affect "arc.weight". + StateId new_lm_state = PropagateLm(tok->lm_state, &arc); // may affect + // "arc.weight". BaseFloat ac_cost = link->acoustic_cost, graph_cost = arc.weight.Value(), cur_cost = tok->tot_cost, @@ -235,20 +242,28 @@ cutoff_(frame+1) : std::numeric_limits::infinity(); // "next_tok" which is the destation of "shadowing_token". So they are // estimated rather than exact. 
They will be used to initialize a new // token and help to decide the new token will be expanded or not - BaseFloat extra_cost = next_tok->extra_cost + tot_cost - next_tok->tot_cost, // inherit backward cost, use its own tot_cost + + // inherit backward cost, use its own tot_cost + BaseFloat extra_cost = next_tok->extra_cost + tot_cost - + next_tok->tot_cost, backward_cost = next_tok->backward_cost; // prepare to store a new token in the current / next frame - if (new_frame_index+1 < cutoff_.Dim() && - tot_cost > cutoff_(new_frame_index+1)) continue; + if (new_frame_index + 1 < cutoff_.Dim() && + tot_cost > cutoff_(new_frame_index + 1) ) continue; if (extra_cost > config_.lattice_beam) continue; - Token* new_tok = ExpandShadowTokensSub(ilabel, new_hclg_state, new_lm_state, frame, new_frame_index, tot_cost, extra_cost, backward_cost, is_last); + Token* new_tok = ExpandShadowTokensSub(ilabel, new_hclg_state, + new_lm_state, frame, new_frame_index, tot_cost, extra_cost, + backward_cost, is_last, + next_tok); + if (!new_tok) continue; // create lattice arc tok->links = new ForwardLink(new_tok, arc.ilabel, arc.olabel, graph_cost, ac_cost, tok->links, graph_cost_ori); } // end of for loop } + while (links_to_clear) { ForwardLink* l=links_to_clear->next; delete links_to_clear; @@ -258,10 +273,21 @@ cutoff_(frame+1) : std::numeric_limits::infinity(); } // Clean the backfill map - cur_h.clear(); if (is_last) toks_backfill_hclg_[cur_frame]->clear(); + if (cur_frame == NumFramesDecoded()) { + for(Token *cur_tok = active_toks_[cur_frame].toks; cur_tok != NULL; cur_tok = cur_tok->next) { + PairId pair = ConstructPair(cur_tok->hclg_state, cur_tok->lm_state); + auto elem_main = toks_.Find(pair); + if (!elem_main) toks_.Insert(pair, cur_tok); + } + // sanity check + KALDI_ASSERT(ToksNum(cur_frame) == cur_h.size()); + } + cur_h.clear(); - KALDI_VLOG(2) << "expand fr num: " << cur_frame << " " << is_last << " " << GetExpandQueue(cur_frame+1).size() << " " << ToksNum(cur_frame); + 
KALDI_VLOG(2) << "expand fr num: " << cur_frame << " " << is_last << " " + << GetExpandQueue(cur_frame+1).size() << " " + << ToksNum(cur_frame); expanding_=false; expand_time_ += timer.Elapsed(); } @@ -602,9 +628,12 @@ void Lattice2BiglmFasterDecoder::PruneTokensForFrame(int32 frame, bool is_expand if (tok->extra_cost == std::numeric_limits::infinity()) { // token is unreachable from end of graph; (no forward links survived) // excise tok from list and delete tok. - if (toks_backfill_hclg_.size()>frame && frame >= NumFramesDecoded()-config_.prune_interval) { // the map has been built - if (toks_backfill_hclg_[frame]->erase(tok->hclg_state)) - ; //for (Token* t=toks; t; t=t->next) KALDI_ASSERT(t->shadowing_tok!=tok); // sanity check + if (toks_backfill_hclg_.size()>frame && + frame >= NumFramesDecoded()-config_.prune_interval-config_.explore_interval) { // the map has been built + auto r=toks_backfill_hclg_[frame]->find(tok->hclg_state); + // r==end() is from we didn't update toks_backfill_hclg_ in pne + if (r!=toks_backfill_hclg_[frame]->end() && r->second==tok) + toks_backfill_hclg_[frame]->erase(tok->hclg_state); } if (prev_tok != NULL) prev_tok->next = tok->next; else toks = tok->next; @@ -721,14 +750,16 @@ BaseFloat Lattice2BiglmFasterDecoder::GetCutoff(Elem *list_head, } } -Lattice2BiglmFasterDecoder::Token* Lattice2BiglmFasterDecoder::ExpandShadowTokensSub(StateId ilabel, - StateId new_hclg_state, StateId new_lm_state, int32 frame, - int32 new_frame_index, BaseFloat tot_cost, BaseFloat extra_cost, BaseFloat backward_cost, - bool is_last) { +Lattice2BiglmFasterDecoder::Token* Lattice2BiglmFasterDecoder::ExpandShadowTokensSub( + StateId ilabel, StateId new_hclg_state, StateId new_lm_state, int32 frame, + int32 new_frame_index, BaseFloat tot_cost, BaseFloat extra_cost, + BaseFloat backward_cost, bool is_last, + Token* shadowing_tok) { Token *&toks = ilabel ? 
active_toks_[frame+1].toks : active_toks_[frame].toks; assert(toks); - Token *tok_found=NULL; + // Existing token or not + Token *tok_found = NULL; PairId new_pair = ConstructPair(new_hclg_state, new_lm_state); auto& next_h = GetBackfillMap(new_frame_index); auto& next_q = GetExpandQueue(new_frame_index); @@ -737,7 +768,7 @@ Lattice2BiglmFasterDecoder::Token* Lattice2BiglmFasterDecoder::ExpandShadowToken tok_found = iter->second; Token* new_tok; - bool update_tok=false; + bool update_tok = false; if (!tok_found) { // A new token. // Construct the new token. new_tok = new Token(tot_cost, extra_cost, NULL, toks, new_lm_state, @@ -758,7 +789,7 @@ Lattice2BiglmFasterDecoder::Token* Lattice2BiglmFasterDecoder::ExpandShadowToken } } - bool better_hclg=false; + bool better_hclg = false; KALDI_ASSERT(toks_backfill_hclg_.size() > new_frame_index); auto iter_hclg = (*toks_backfill_hclg_[new_frame_index]).find(new_hclg_state); if (iter_hclg != (*toks_backfill_hclg_[new_frame_index]).end()) { @@ -766,7 +797,12 @@ Lattice2BiglmFasterDecoder::Token* Lattice2BiglmFasterDecoder::ExpandShadowToken better_hclg=true; // search: "Update toks_shadowing_mod for better_hclg" // although it is better hclg, we still keep its shadowing token for expanding in the next iter } else { - (*toks_backfill_hclg_[new_frame_index])[new_hclg_state] = new_tok; + // if shadowing_tok==NULL, this call is from cur_better_hclg=true; + // otherwise, we have a shadowing_tok for the current token + if (shadowing_tok) + (*toks_backfill_hclg_[new_frame_index])[new_hclg_state] = shadowing_tok; + else // we cannot guarantee new_tok is the best one in toks_backfill_hclg_ (it is the first one) + (*toks_backfill_hclg_[new_frame_index])[new_hclg_state] = new_tok; iter_hclg = (*toks_backfill_hclg_[new_frame_index]).find(new_hclg_state); } @@ -778,19 +814,22 @@ Lattice2BiglmFasterDecoder::Token* Lattice2BiglmFasterDecoder::ExpandShadowToken // search the comments above regarding to: // "we need to update a shadowing 
token itself" new_tok->shadowing_tok = NULL; - } + } else KALDI_ASSERT(new_tok->shadowing_tok->shadowing_tok != new_tok); + if (is_last || better_hclg || new_frame_index == frame) { if (new_tok->shadowing_tok) { // prepare for forwardlinks updating // sanity check - // KALDI_ASSERT(!new_tok->shadowing_tok->shadowing_tok || new_tok->shadowing_tok->shadowing_tok != new_tok); + // KALDI_ASSERT(!new_tok->shadowing_tok->shadowing_tok || + // new_tok->shadowing_tok->shadowing_tok != new_tok); new_tok->DeleteForwardLinks(); } - next_q.push(QElem(new_tok, better_hclg)); + next_q.push(QElem(new_tok, better_hclg || (config_.better_hclg==3 && update_tok))); new_tok->in_queue=true; } } return new_tok; } + void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, int32 frame) { Timer timer; @@ -800,17 +839,22 @@ void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, DeleteElemsShadow(toks_shadowing_mod); Elem *last_toks = toks_.Clear(); // swapping prev_toks_ / cur_toks_ + + // Get the best elem to estimate cutoff Elem *best_elem = NULL; BaseFloat adaptive_beam; size_t tok_cnt; BaseFloat cur_cutoff = GetCutoff(last_toks, &tok_cnt, &adaptive_beam, &best_elem); PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough. 
+ // TODO PossiblyResizeHash for toks_shadowing_ - KALDI_VLOG(6) << "Adaptive beam on frame " << frame << "\t" << NumFramesDecoded() << " is " + KALDI_VLOG(6) << "Adaptive beam on frame " << frame << "\t" + << NumFramesDecoded() << " is " << adaptive_beam << "\t" << cur_cutoff; - if (cutoff_.Dim()<=frame) { - cutoff_.Resize(frame+1,kCopyData); - cutoff_.Data()[frame]=cur_cutoff; + + if (cutoff_.Dim() <= frame) { + cutoff_.Resize(frame + 1, kCopyData); + cutoff_.Data()[frame] = cur_cutoff; } @@ -852,12 +896,15 @@ void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, StateId state = PairToState(state_pair), lm_state = PairToLmState(state_pair); Token *tok = e->val; - if (tok->tot_cost < cur_cutoff) { + if (tok->tot_cost < cur_cutoff) { // cur_cutoff is used to limit prev_toks + ElemShadow *elem = toks_shadowing_check.Find(state); - assert(elem); - if (elem->val == tok || // explore + if (!elem || // better_hclg + elem->val == tok || // explore !tok->shadowing_tok || - *tok < *tok->shadowing_tok) { // it is generated by better_hclg; otherwise tok->shadowing_tok should be set by pne in the last frame + *tok < *tok->shadowing_tok) { // it is generated by better_hclg; + // otherwise tok->shadowing_tok should be + // set by pne in the last frame tok->shadowing_tok = NULL; for (fst::ArcIterator > aiter(fst_, state); !aiter.Done(); @@ -874,7 +921,9 @@ void Lattice2BiglmFasterDecoder::ProcessEmitting(DecodableInterface *decodable, if (tot_cost > next_cutoff) continue; else if (tot_cost + config_.beam < next_cutoff) next_cutoff = tot_cost + config_.beam; // prune by best current token - + + // If the tot_cost <= next_cutoff, it will be inserted into current + // token list PairId next_pair = ConstructPair(arc.nextstate, next_lm_state); Token *next_tok = FindOrAddToken(next_pair, frame, tot_cost, true, NULL); // true: emitting, NULL: no change indicator needed @@ -914,9 +963,13 @@ void Lattice2BiglmFasterDecoder::ProcessNonemitting(int32 frame) { // 
problem did not improve overall speed. KALDI_ASSERT(queue_.empty()); + // In ProcessNonemitting, only the epsilon arcs will be processed. So the + // toks_shadowing_check and toks_shadowing_mod point to the same slot. HashList &toks_shadowing_check=toks_shadowing_[frame%2]; HashList &toks_shadowing_mod=toks_shadowing_[frame%2]; + // Iterate over the token list of the current frame. Compute the cutoff and + // push each token into the queue. BaseFloat best_cost = std::numeric_limits::infinity(); for (const Elem *e = toks_.GetList(); e != NULL; e = e->tail) { queue_.push_back(e->key); @@ -947,7 +1000,8 @@ void Lattice2BiglmFasterDecoder::ProcessNonemitting(int32 frame) { // of non-optimality (remember, this is the simple decoder), // but since most states are emitting it's not a huge issue. ElemShadow *elem = toks_shadowing_check.Find(state); - assert(elem); + assert(elem); // In exploration stage, we only process the best token for + // each hclg state. if (elem->val == tok) { // Explore the best token in certain HCLG state tok->DeleteForwardLinks(); // necessary when re-visiting tok->links = NULL; @@ -968,7 +1022,7 @@ void Lattice2BiglmFasterDecoder::ProcessNonemitting(int32 frame) { false, &changed); // false: non-emit ElemShadow *elem = toks_shadowing_mod.Find(arc.nextstate); if (elem) { - if ((*elem->val) > *new_tok) + if ((*elem->val) > *new_tok) // new token with better tot_cost elem->val = new_tok; // update it } else { toks_shadowing_mod.Insert(arc.nextstate, new_tok); @@ -993,24 +1047,27 @@ void Lattice2BiglmFasterDecoder::ProcessNonemitting(int32 frame) { ElemShadow *elem = toks_shadowing_mod.Find(cur_tok->hclg_state); assert(elem); if (cur_tok == elem->val){ - cur_tok->shadowing_tok = NULL; + cur_tok->shadowing_tok = NULL; // this token has been expanded and it is + // the best token for certain hclg state } else { - cur_tok->shadowing_tok = elem->val; + cur_tok->shadowing_tok = elem->val; // shadowed by best token // sanity check - 
KALDI_ASSERT(!cur_tok->shadowing_tok->shadowing_tok || cur_tok->shadowing_tok->shadowing_tok != cur_tok); + KALDI_ASSERT(!cur_tok->shadowing_tok->shadowing_tok || + cur_tok->shadowing_tok->shadowing_tok != cur_tok); cur_tok->extra_cost=std::numeric_limits::infinity(); - cur_tok->DeleteForwardLinks(); // since some tok could be shadowed after exploring in the same decoding step + cur_tok->DeleteForwardLinks(); // since some tok could be shadowed after + // exploring in the same decoding step } } - BuildHCLGMapFromHash(frame); // do it here to make it consistent - tb_+=timer.Elapsed(); + //tb_+=timer.Elapsed(); } void Lattice2BiglmFasterDecoder::BuildHCLGMapFromHash(int32 frame, bool append) { + // append indicates this is a new frame or not. if (!append) KALDI_ASSERT(toks_backfill_hclg_.size() > frame); + HashList &toks_shadowing_mod=toks_shadowing_[frame%2]; - StateHash *hclg_map = - new StateHash(); + StateHash *hclg_map = new StateHash(); hclg_map->reserve(toks_shadowing_mod.Size()); for (const ElemShadow *e = toks_shadowing_mod.GetList(); e != NULL; e = e->tail) { @@ -1030,6 +1087,7 @@ void Lattice2BiglmFasterDecoder::BuildHCLGMapFromHash(int32 frame, bool append) // //i.second->links is possible to be NULL since it is possible hasnt been pruned // } } + void Lattice2BiglmFasterDecoder::InitDecoding() { for (int i=0; i<2; i++) { toks_backfill_pair_[i].clear(); @@ -1042,7 +1100,7 @@ void Lattice2BiglmFasterDecoder::InitDecoding() { ClearActiveTokens(); cutoff_.Resize(1); - cutoff_.Data()[0] = std::numeric_limits::max(); + cutoff_.Data()[0] = std::numeric_limits::max(); // clean up private members warned_noarc_ = false; @@ -1052,10 +1110,11 @@ void Lattice2BiglmFasterDecoder::InitDecoding() { num_toks_ = 0; // At the beginning of an utterance, initialize. 
- toks_backfill_hclg_.resize(0); + ClearHCLGMap(); PairId start_pair = ConstructPair(fst_.Start(), lm_diff_fst_->Start()); active_toks_.resize(1); - Token *start_tok = new Token(0.0, 0.0, NULL, NULL, lm_diff_fst_->Start(), fst_.Start()); + Token *start_tok = new Token(0.0, 0.0, NULL, NULL, lm_diff_fst_->Start(), + fst_.Start()); active_toks_[0].toks = start_tok; toks_.Insert(start_pair, start_tok); toks_shadowing_[NumFramesDecoded()%2].Insert(fst_.Start(), start_tok); @@ -1064,34 +1123,62 @@ void Lattice2BiglmFasterDecoder::InitDecoding() { propage_lm_expand_num_=0; ProcessNonemitting(0); } -void Lattice2BiglmFasterDecoder::BuildBackfillMap(int32 frame, int32 frame_stop_expand, bool clear) { +void Lattice2BiglmFasterDecoder::BuildBackfillMap(int32 frame, + int32 frame_stop_expand, + bool clear) { + Timer timer; PairHash *pair_map = &GetBackfillMap(frame); std::queue& q = GetExpandQueue(frame); if (clear) pair_map->clear(); BaseFloat cur_cutoff = (frame+1 < cutoff_.Dim())? -cutoff_(frame+1) : std::numeric_limits::infinity(); + cutoff_(frame+1) : + std::numeric_limits::infinity(); + + StateHash *hclg_map = NULL; + if (frame >= toks_backfill_hclg_.size()) { + KALDI_ASSERT(frame == toks_backfill_hclg_.size()); + hclg_map = new StateHash(); + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + if (tok->tot_cost > cur_cutoff) continue; + auto r = hclg_map->insert({tok->hclg_state, tok}); + bool ok = r.second; + if (!ok && *r.first->second > *tok) // with this tok before + r.first->second = tok; + } + toks_backfill_hclg_.push_back(hclg_map); + } else hclg_map = toks_backfill_hclg_[frame]; - for(Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { + for (Token *tok = active_toks_[frame].toks; tok != NULL; tok = tok->next) { if (tok->tot_cost > cur_cutoff) { - tok->shadowing_tok=NULL; - tok->in_queue=false; + tok->shadowing_tok = NULL; // useless token, will not be expanded + tok->in_queue = false; // don't put it into the queue 
continue; } PairId cur_pair_id = ConstructPair(tok->hclg_state, tok->lm_state); + /* // TODO: we do not do it since we didn't update toks_backfill_hclg_ + Token* shadowing_tok = (*toks_backfill_hclg_[frame])[tok->hclg_state]; + if (tok==shadowing_tok) tok->shadowing_tok = NULL; + else tok->shadowing_tok = shadowing_tok; + */ + + // If this is a new pair id, insert it to pair_map bool ok = pair_map->insert({cur_pair_id, tok}).second; + // Build the expand queue. if (frame <= frame_stop_expand) { // need to expand if (ok) { // without this tok before if (tok->shadowing_tok) { q.push(QElem(tok, false)); tok->in_queue=true; } else tok->in_queue = false; - } else KALDI_ASSERT(tok->in_queue); // tok has been pushed by ExpandShadowTokens + } else KALDI_ASSERT(tok->in_queue); // tok has been pushed by + // ExpandShadowTokens } } + tb_+=timer.Elapsed(); } } diff --git a/src/decoder/lattice2-biglm-faster-decoder.h b/src/decoder/lattice2-biglm-faster-decoder.h index 26bc7388c07..54c17148b69 100644 --- a/src/decoder/lattice2-biglm-faster-decoder.h +++ b/src/decoder/lattice2-biglm-faster-decoder.h @@ -123,16 +123,18 @@ class Lattice2BiglmFasterDecoder { Lattice2BiglmFasterDecoderConfig GetOptions() { return config_; } + // Clean up backfill map + void ClearHCLGMap() { + for (auto e:toks_backfill_hclg_) delete e; + toks_backfill_hclg_.resize(0); + } // Releases the HashList and Backfill Maps which are created by // BuildBackfillMap() ~Lattice2BiglmFasterDecoder() { DeleteElems(toks_.Clear()); for (int i = 0; i < 2; i++) DeleteElemsShadow(toks_shadowing_[i]); ClearActiveTokens(); - // Clean up backfill map - for (int32 frame = NumFramesDecoded(); frame >= 0; frame--) { - delete toks_backfill_hclg_[frame]; - } + ClearHCLGMap(); KALDI_VLOG(1) << "time: " << expand_time_ << " " << propage_time_<< " " << ta_ << " " << tb_; } @@ -441,6 +443,9 @@ class Lattice2BiglmFasterDecoder { // more than one list (e.g. 
for current and previous frames), but only one of // them at a time can be indexed by StateId. HashList toks_; + // toks_shadowing_ is used in the exploration stage. Its two lists record the + // best token for each HCLG state (i.e. the key is StateId rather than PairId) + // on the previous and current frames. HashList toks_shadowing_[2]; // When do expanding, we have two special cases need to be processed. @@ -465,7 +470,9 @@ class Lattice2BiglmFasterDecoder { std::vector toks_backfill_hclg_; typedef std::pair QElem; std::queue expand_current_frame_queue_[2]; - std::queue& GetExpandQueue(int32 frame) { return expand_current_frame_queue_[frame%2]; } + std::queue& GetExpandQueue(int32 frame) { + return expand_current_frame_queue_[frame%2]; + } PairHash& GetBackfillMap(int32 frame) { return toks_backfill_pair_[frame%2]; } void InitDecoding(); @@ -552,7 +559,8 @@ class Lattice2BiglmFasterDecoder { Token *ExpandShadowTokensSub(StateId ilabel, StateId new_hclg_state, StateId new_lm_state, int32 frame, int32 new_frame_index, BaseFloat tot_cost, BaseFloat extra_cost, BaseFloat backward_cost, - bool is_last); + bool is_last, + Token* shadowing_tok = NULL); Vector cutoff_; uint64 propage_lm_num_;