Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
386 changes: 386 additions & 0 deletions egs/swbd/s5c/local/chain/tuning/run_tdnn_7n.sh

Large diffs are not rendered by default.

411 changes: 411 additions & 0 deletions egs/swbd/s5c/local/chain/tuning/run_tdnn_7o.sh

Large diffs are not rendered by default.

22 changes: 13 additions & 9 deletions src/chain/chain-supervision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,11 +650,11 @@ void AppendSupervision(const std::vector<const Supervision*> &input,
}
}

bool AddWeightToSupervisionFst(const fst::StdVectorFst &normalization_fst,
Supervision *supervision) {
// remove epsilons before composing. 'normalization_fst' has noepsilons so
bool AddWeightToFst(const fst::StdVectorFst &normalization_fst,
fst::StdVectorFst *supervision_fst) {
// remove epsilons before composing. 'normalization_fst' has no epsilons so
// the composed result will be epsilon free.
fst::StdVectorFst supervision_fst_noeps(supervision->fst);
fst::StdVectorFst supervision_fst_noeps(*supervision_fst);
fst::RmEpsilon(&supervision_fst_noeps);
if (!TryDeterminizeMinimize(kSupervisionMaxStates,
&supervision_fst_noeps))
Expand All @@ -673,15 +673,19 @@ bool AddWeightToSupervisionFst(const fst::StdVectorFst &normalization_fst,
if (!TryDeterminizeMinimize(kSupervisionMaxStates,
&composed_fst))
return false;
supervision->fst = composed_fst;

*supervision_fst = composed_fst;
// Make sure the states are numbered in increasing order of time.
SortBreadthFirstSearch(&(supervision->fst));
KALDI_ASSERT(supervision->fst.Properties(fst::kAcceptor, true) == fst::kAcceptor);
KALDI_ASSERT(supervision->fst.Properties(fst::kIEpsilons, true) == 0);
SortBreadthFirstSearch(supervision_fst);
KALDI_ASSERT(supervision_fst->Properties(fst::kAcceptor, true) == fst::kAcceptor);
KALDI_ASSERT(supervision_fst->Properties(fst::kIEpsilons, true) == 0);
return true;
}

// Applies the weights from 'normalization_fst' to the supervision object's
// FST, in place.  Thin wrapper around AddWeightToFst(); returns false if
// the underlying composition/determinization failed, true otherwise.
bool AddWeightToSupervisionFst(const fst::StdVectorFst &normalization_fst,
                               Supervision *supervision) {
  fst::StdVectorFst *target_fst = &(supervision->fst);
  return AddWeightToFst(normalization_fst, target_fst);
}

void SplitIntoRanges(int32 num_frames,
int32 frames_per_range,
std::vector<int32> *range_starts) {
Expand Down
3 changes: 3 additions & 0 deletions src/chain/chain-supervision.h
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,9 @@ class SupervisionSplitter {
/// This function also removes epsilons and makes sure supervision->fst has the
/// required sorting of states. Think of it as the final stage in preparation
/// of the supervision FST.
bool AddWeightToFst(const fst::StdVectorFst &normalization_fst,
fst::StdVectorFst *supervision_fst);

bool AddWeightToSupervisionFst(const fst::StdVectorFst &normalization_fst,
Supervision *supervision);

Expand Down
50 changes: 50 additions & 0 deletions src/chain/chain-training.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,56 @@
namespace kaldi {
namespace chain {

void ComputeObjfAndDeriv2(const ChainTrainingOptions &opts,
const DenominatorGraph &den_graph,
const GeneralMatrix &supervision,
const CuMatrixBase<BaseFloat> &nnet_output,
int32 num_sequences, int32 frames_per_sequence,
BaseFloat *objf,
BaseFloat *l2_term,
BaseFloat *weight,
CuMatrixBase<BaseFloat> *nnet_output_deriv,
CuMatrixBase<BaseFloat> *xent_output_deriv) {
if (nnet_output_deriv) {
nnet_output_deriv->SetZero();
nnet_output_deriv->CopyFromMat(supervision.GetFullMatrix());
if (xent_output_deriv)
xent_output_deriv->CopyFromMat(*nnet_output_deriv);
} else if (xent_output_deriv) {
// this branch will be taken if xent_output_deriv but not
// nnet_output_deriv is set- which could happen if you want to compute the
// cross-entropy objective but not the derivatives.
xent_output_deriv->SetZero();
xent_output_deriv->CopyFromMat(supervision.GetFullMatrix());
}
BaseFloat sup_weight = 1.0;
DenominatorComputation denominator(opts, den_graph,
num_sequences,
nnet_output);
BaseFloat den_logprob = denominator.Forward();
bool ok = true;
if (nnet_output_deriv)
ok = denominator.Backward(-sup_weight, nnet_output_deriv);
// we don't consider log-prob w.r.t numerator.
*objf = -sup_weight * den_logprob;
*weight = sup_weight * num_sequences * frames_per_sequence;

if (!((*objf) - (*objf) == 0) || !ok) {
// inf or NaN detected, or denominator computation returned false.
if (nnet_output_deriv)
nnet_output_deriv->SetZero();
if (xent_output_deriv)
xent_output_deriv->SetZero();
BaseFloat default_objf = -10;
KALDI_WARN << "Objective function is " << (*objf)
<< " and denominator computation (if done) returned "
<< std::boolalpha << ok
<< ", setting objective function to " << default_objf
<< " per frame.";
*objf = default_objf * *weight;
}
}

void ComputeChainObjfAndDeriv(const ChainTrainingOptions &opts,
const DenominatorGraph &den_graph,
const Supervision &supervision,
Expand Down
18 changes: 15 additions & 3 deletions src/chain/chain-training.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct ChainTrainingOptions {

ChainTrainingOptions(): l2_regularize(0.0), leaky_hmm_coefficient(1.0e-05),
xent_regularize(0.0) { }

void Register(OptionsItf *opts) {
opts->Register("l2-regularize", &l2_regularize, "l2 regularization "
"constant for 'chain' training, applied to the output "
Expand Down Expand Up @@ -121,8 +121,20 @@ void ComputeChainObjfAndDeriv(const ChainTrainingOptions &opts,
BaseFloat *weight,
CuMatrixBase<BaseFloat> *nnet_output_deriv,
CuMatrixBase<BaseFloat> *xent_output_deriv = NULL);


/**
This function uses the supervision matrix as the numerator posteriors and
performs only the denominator computation.  It can be used where the
numerator is fixed, e.g. in teacher-student (TS) learning.
*/
void ComputeObjfAndDeriv2(const ChainTrainingOptions &opts,
const DenominatorGraph &den_graph,
const GeneralMatrix &supervision,
const CuMatrixBase<BaseFloat> &nnet_output,
int32 num_sequences, int32 frames_per_sequence,
BaseFloat *objf,
BaseFloat *l2_term,
BaseFloat *weight,
CuMatrixBase<BaseFloat> *nnet_output_deriv,
CuMatrixBase<BaseFloat> *xent_output_deriv = NULL);

} // namespace chain
} // namespace kaldi
Expand Down
2 changes: 1 addition & 1 deletion src/chainbin/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ LDFLAGS += $(CUDA_LDFLAGS)
LDLIBS += $(CUDA_LDLIBS)

BINFILES = chain-est-phone-lm chain-get-supervision chain-make-den-fst \
nnet3-chain-get-egs nnet3-chain-copy-egs nnet3-chain-merge-egs \
nnet3-chain-get-egs nnet3-chain-get-egs-post nnet3-chain-copy-egs nnet3-chain-merge-egs \

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You forgot to add nnet3-chain-get-egs-post

nnet3-chain-shuffle-egs nnet3-chain-subset-egs \
nnet3-chain-acc-lda-stats nnet3-chain-train nnet3-chain-compute-prob \
nnet3-chain-combine nnet3-chain-normalize-egs
Expand Down
Loading