Hi, team!
I have encountered a problem with k2 in my code. Below is the description of this problem.
For an nnet_output with shape [B, T, D], I am trying to compute the scores on a graph (the MMI numerator or denominator) for every prefix segment of nnet_output, namely nnet_output[:, :t, :], where t is any length up to T (the total length along the time axis).
Currently I do this with a loop, but that is computationally expensive because it runs one intersection per prefix. My code is below:
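import k2
import torch

graph, _ = self.graph_compiler.compile(texts, self.P, replicate_den=True)
T = x.size(1)
tot_scores = []
# One full intersection per prefix length, from the longest (T) down to 1.
for t in range(T, 0, -1):
    # Each supervision row is [sequence idx, start frame, number of frames];
    # only sequence 0 is scored here.
    supervision = torch.tensor([[0, 0, t]], dtype=torch.int32)
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision)
    lats = k2.intersect_dense(graph, dense_fsa_vec, output_beam=10.0)
    frame_score = lats.get_tot_scores(log_semiring=True, use_double_scores=True)
    tot_scores.append(frame_score)
# tot_scores[0] is the score of the full sequence, tot_scores[-1] that of the first frame only.
tot_scores = torch.cat(tot_scores)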
Could these scores be obtained by parsing the lattice computed from the whole nnet_output, so that only a single k2.intersect_dense call is needed? An approximation would also be fine for me.
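To make the idea concrete, here is a minimal sketch of what I have in mind. It is plain Python, not k2: the toy graph, the helper names (log_add, prefix_scores) and the per-state final weights are all made up for illustration. It shows that one forward pass in the log semiring yields the score of every prefix, because the forward scores after frame t depend only on frames up to t. I do not know whether the pruned lattice returned by k2.intersect_dense keeps enough states for this to be exact, which is why an approximation would be acceptable.

```python
import math

NEG_INF = float("-inf")


def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def prefix_scores(arcs, final_scores, num_states, log_probs):
    """Log-semiring total score of every prefix, from one forward pass.

    arcs: list of (src, dst, label, score); state 0 is the start state.
    final_scores: dict mapping a state to the log weight of ending there
        (a simplification assumed for this toy example).
    log_probs: T frames, each a list of D log-probabilities.
    Returns a list whose entry t-1 is the score of log_probs[:t].
    """
    alpha = [NEG_INF] * num_states  # forward scores before any frame
    alpha[0] = 0.0
    scores = []
    for frame in log_probs:
        nxt = [NEG_INF] * num_states
        for src, dst, label, score in arcs:
            nxt[dst] = log_add(nxt[dst], alpha[src] + score + frame[label])
        alpha = nxt
        # Close the prefix here: sum over states that are allowed to end.
        total = NEG_INF
        for state, weight in final_scores.items():
            total = log_add(total, alpha[state] + weight)
        scores.append(total)
    return scores


# Hypothetical 2-state graph over 2 symbols, and 3 frames of log-probs.
arcs = [(0, 0, 0, 0.0), (0, 1, 1, 0.0), (1, 1, 1, 0.0), (1, 1, 0, -0.5)]
final_scores = {1: 0.0}
log_probs = [
    [math.log(0.6), math.log(0.4)],
    [math.log(0.3), math.log(0.7)],
    [math.log(0.5), math.log(0.5)],
]

scores = prefix_scores(arcs, final_scores, 2, log_probs)

# Reference: rerun from scratch for each prefix length, as in my loop above.
looped = [
    prefix_scores(arcs, final_scores, 2, log_probs[:t])[-1]
    for t in range(1, len(log_probs) + 1)
]
```

Note that scores is ordered from the shortest prefix to the longest, which is the reverse of tot_scores in my loop.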
Thanks for your help! :)